Welcome to My Notes Site
Welcome to a curated collection of notes and resources spanning a wide array of topics, from programming languages and data structures to machine learning and cloud computing. This is my personal knowledge repository, where I document what I learn across these subjects.
Table of Contents
- Git
- Programming Languages
- Linux
- Android
- Data Structures
- Algorithms
- Security
- Computer Networking
- Wi-Fi
- Machine Learning
- AI
- Cloud
- Finance
- Tools
- Embedded
Feel free to explore the sections and delve into the topics that pique your interest. Happy learning!
Note: This is an evolving project, and I will continue to add more content over time. The contents may be rearranged or updated as needed.
Git Version Control
Overview
Git is a distributed version control system designed to handle everything from small to very large projects with speed and efficiency. Created by Linus Torvalds in 2005, Git has become the de facto standard for version control in software development.
What is Version Control?
Version control is a system that records changes to files over time so that you can recall specific versions later. It allows you to:
- Track changes to your code
- Collaborate with other developers
- Revert to previous versions
- Create branches for experimental features
- Merge changes from multiple sources
- Maintain a complete history of your project
Why Git?
- Distributed: Every developer has a full copy of the repository
- Fast: Most operations are local
- Branching: Lightweight and powerful branching model
- Data Integrity: Cryptographic hash (SHA-1) ensures data integrity
- Staging Area: Review changes before committing
- Open Source: Free and widely supported
Git Basics
The Three States
Git has three main states for your files:
- Modified: Changed but not committed
- Staged: Marked for next commit
- Committed: Safely stored in local database
Working Directory → Staging Area → Git Repository
      (edit)            (stage)         (commit)
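These states are easy to observe with `git status`. A minimal sketch in a throwaway repository (the name state-demo and file notes.txt are illustrative):

```shell
# Watch a file move through the three states
git init state-demo && cd state-demo
echo "demo" > notes.txt
git status --short    # ?? notes.txt   (untracked)
git add notes.txt
git status --short    # A  notes.txt   (staged)
git commit -m "Add notes file"
git status --short    # no output      (committed, clean tree)
echo "more" >> notes.txt
git status --short    #  M notes.txt   (modified, not staged)
```

The two-column short format shows the staging-area status on the left and the working-tree status on the right, which maps directly onto the diagram above.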
Git Workflow
# 1. Make changes in working directory
echo "Hello World" > file.txt
# 2. Stage changes
git add file.txt
# 3. Commit changes
git commit -m "Add hello world file"
# 4. Push to remote repository
git push origin main
Installation and Setup
Installation
# Linux (Debian/Ubuntu)
sudo apt-get update
sudo apt-get install git
# Linux (Fedora)
sudo dnf install git
# macOS (Homebrew)
brew install git
# Windows
# Download from https://git-scm.com/download/win
# Verify installation
git --version
Initial Configuration
# Set user name
git config --global user.name "Your Name"
# Set email
git config --global user.email "your.email@example.com"
# Set default editor
git config --global core.editor "vim"
# Set default branch name
git config --global init.defaultBranch main
# Enable color output
git config --global color.ui auto
# View all settings
git config --list
# View specific setting
git config user.name
# Edit config file directly
git config --global --edit
Basic Commands
Creating Repositories
# Initialize new repository
git init
# Clone existing repository
git clone https://github.com/user/repo.git
# Clone to specific directory
git clone https://github.com/user/repo.git my-project
# Clone specific branch
git clone -b develop https://github.com/user/repo.git
Making Changes
# Check status
git status
# Add file to staging
git add file.txt
# Add all files
git add .
# Add all files with specific extension
git add *.js
# Interactive staging
git add -p
# Commit staged changes
git commit -m "Commit message"
# Commit with detailed message
git commit
# Stage and commit in one step
git commit -am "Message"
# Amend last commit
git commit --amend
# Amend without changing message
git commit --amend --no-edit
Viewing History
# View commit history
git log
# Compact log
git log --oneline
# Graph view
git log --graph --oneline --all
# Limit number of commits
git log -n 5
# Show commits by author
git log --author="John"
# Show commits in date range
git log --since="2 weeks ago"
git log --until="2024-01-01"
# Show file changes
git log --stat
# Show detailed changes
git log -p
# Search commit messages
git log --grep="fix"
# Show commits affecting specific file
git log -- file.txt
Viewing Changes
# Show unstaged changes
git diff
# Show staged changes
git diff --staged
# Show changes in specific file
git diff file.txt
# Compare branches
git diff main..feature
# Compare commits
git diff commit1 commit2
# Word-level diff
git diff --word-diff
Branching and Merging
Branches
# List branches
git branch
# List all branches (including remote)
git branch -a
# Create new branch
git branch feature-name
# Switch to branch
git checkout feature-name
# Create and switch in one command
git checkout -b feature-name
# Modern syntax (Git 2.23+)
git switch feature-name
git switch -c feature-name
# Delete branch
git branch -d feature-name
# Force delete unmerged branch
git branch -D feature-name
# Rename current branch
git branch -m new-name
# Rename specific branch
git branch -m old-name new-name
Merging
# Merge branch into current branch
git merge feature-name
# Merge with commit message
git merge feature-name -m "Merge feature"
# Merge without fast-forward
git merge --no-ff feature-name
# Abort merge
git merge --abort
# Continue merge after resolving conflicts
git merge --continue
Handling Merge Conflicts
# When merge conflict occurs:
# 1. Check conflicted files
git status
# 2. Open files and resolve conflicts
# Look for markers: <<<<<<<, =======, >>>>>>>
# 3. After resolving, stage files
git add resolved-file.txt
# 4. Complete merge
git commit
# Or use merge tool
git mergetool
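To see those markers in practice, this throwaway-repository sketch (all branch and file names illustrative) manufactures a conflict on purpose:

```shell
# Create a conflict: two branches change the same line of the same file
git init conflict-demo && cd conflict-demo
echo "original line" > greeting.txt
git add greeting.txt && git commit -m "Base version"
git checkout -b feature
echo "feature line" > greeting.txt
git commit -am "Feature version"
git checkout -                   # back to the default branch
echo "main line" > greeting.txt
git commit -am "Main version"
git merge feature                # stops with CONFLICT (content)
cat greeting.txt                 # <<<<<<< HEAD / ======= / >>>>>>> feature
```

Edit greeting.txt down to the content you want, then `git add greeting.txt` and `git commit` to finish, exactly as in the steps above.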
Rebasing
# Rebase current branch onto main
git rebase main
# Interactive rebase (last 3 commits)
git rebase -i HEAD~3
# Continue after resolving conflicts
git rebase --continue
# Skip current commit
git rebase --skip
# Abort rebase
git rebase --abort
# Rebase options in interactive mode:
# pick = use commit
# reword = use commit, but edit message
# edit = use commit, but stop for amending
# squash = meld into previous commit (combine messages)
# fixup = meld into previous commit (discard this message)
# drop = remove commit
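Interactive rebase can also be scripted by overriding the todo-list editor through the GIT_SEQUENCE_EDITOR variable. A sketch (GNU sed assumed) that melds the newest commit into its predecessor without opening an editor:

```shell
# Rewrite line 2 of the rebase todo list from "pick" to "fixup",
# squashing the newest commit into the one before it and
# discarding the newest commit's message.
GIT_SEQUENCE_EDITOR="sed -i '2s/^pick/fixup/'" git rebase -i HEAD~2
```

The todo list presents commits oldest first, so line 2 is the newest of the two commits being rebased.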
Remote Repositories
Working with Remotes
# List remotes
git remote
# List remotes with URLs
git remote -v
# Add remote
git remote add origin https://github.com/user/repo.git
# Change remote URL
git remote set-url origin https://github.com/user/new-repo.git
# Remove remote
git remote remove origin
# Rename remote
git remote rename origin upstream
# Show remote info
git remote show origin
Fetching and Pulling
# Fetch from remote (doesn't merge)
git fetch origin
# Fetch all remotes
git fetch --all
# Pull (fetch + merge)
git pull origin main
# Pull with rebase
git pull --rebase origin main
# Pull specific branch
git pull origin feature-branch
Pushing
# Push to remote
git push origin main
# Push and set upstream
git push -u origin main
# Push all branches
git push --all origin
# Push tags
git push --tags
# Force push (dangerous!)
git push --force origin main
# Safer force push
git push --force-with-lease origin main
# Delete remote branch
git push origin --delete branch-name
Undoing Changes
Working Directory
# Discard changes in file
git checkout -- file.txt
# Discard all changes
git checkout -- .
# Modern syntax
git restore file.txt
git restore .
# Remove untracked files
git clean -f
# Remove untracked files and directories
git clean -fd
# Preview what will be removed
git clean -n
Staging Area
# Unstage file
git reset HEAD file.txt
# Unstage all files
git reset HEAD
# Modern syntax
git restore --staged file.txt
Commits
# Undo last commit (keep changes)
git reset --soft HEAD~1
# Undo last commit (discard changes)
git reset --hard HEAD~1
# Undo multiple commits
git reset --hard HEAD~3
# Reset to specific commit
git reset --hard commit-hash
# Create new commit that undoes changes
git revert commit-hash
# Revert a range of commits (commit1 itself is excluded)
git revert commit1..commit3
Advanced Features
Stashing
# Stash current changes
git stash
# Stash with message (replaces the deprecated "git stash save")
git stash push -m "Work in progress"
# List stashes
git stash list
# Apply last stash
git stash apply
# Apply specific stash
git stash apply stash@{2}
# Apply and remove stash
git stash pop
# Create branch from stash
git stash branch feature-name
# Drop stash
git stash drop stash@{0}
# Clear all stashes
git stash clear
# Stash including untracked files
git stash -u
Tags
# List tags
git tag
# Create lightweight tag
git tag v1.0.0
# Create annotated tag
git tag -a v1.0.0 -m "Version 1.0.0"
# Tag specific commit
git tag v1.0.0 commit-hash
# Push tag to remote
git push origin v1.0.0
# Push all tags
git push --tags
# Delete local tag
git tag -d v1.0.0
# Delete remote tag
git push origin --delete v1.0.0
# Checkout tag (puts you in detached HEAD state)
git checkout v1.0.0
Cherry-Pick
# Apply specific commit to current branch
git cherry-pick commit-hash
# Cherry-pick multiple commits
git cherry-pick commit1 commit2
# Cherry-pick without committing
git cherry-pick -n commit-hash
# Abort cherry-pick
git cherry-pick --abort
Bisect
# Start bisect session
git bisect start
# Mark current commit as bad
git bisect bad
# Mark known good commit
git bisect good commit-hash
# Git will checkout middle commit
# Test and mark as good or bad
git bisect good # or git bisect bad
# Continue until bug is found
# End bisect session
git bisect reset
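The test-and-mark loop can be automated with `git bisect run`, which executes a command at each step and interprets its exit status: 0 means good, 125 means skip, and any other value up to 127 means bad. A self-contained sketch (repository name and the inline test are illustrative; in real use you would pass your own test script):

```shell
# Throwaway repo: commits 1..5, where the "bug" appears at commit 4
git init bisect-demo && cd bisect-demo
for i in 1 2 3 4 5; do echo "$i" > app.txt; git add app.txt; git commit -m "commit $i"; done
git bisect start HEAD "$(git rev-list --max-parents=0 HEAD)"  # bad commit first, then good (root)
git bisect run sh -c '[ "$(cat app.txt)" -lt 4 ]'             # exit 0 = good, non-zero = bad
git bisect reset                                              # return to where you started
```

Git drives the binary search itself and reports the first bad commit when it finishes.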
Git Workflows
Feature Branch Workflow
# 1. Create feature branch
git checkout -b feature/new-feature
# 2. Make changes and commit
git add .
git commit -m "Implement new feature"
# 3. Push to remote
git push -u origin feature/new-feature
# 4. Create pull request (on GitHub/GitLab)
# 5. After review, merge via web interface
# 6. Update local main branch
git checkout main
git pull origin main
# 7. Delete feature branch
git branch -d feature/new-feature
git push origin --delete feature/new-feature
Gitflow Workflow
# Main branches: main (production), develop (integration)
# Start new feature
git checkout -b feature/feature-name develop
# Finish feature
git checkout develop
git merge --no-ff feature/feature-name
git branch -d feature/feature-name
# Start release
git checkout -b release/1.0.0 develop
# Finish release
git checkout main
git merge --no-ff release/1.0.0
git tag -a v1.0.0
git checkout develop
git merge --no-ff release/1.0.0
git branch -d release/1.0.0
# Hotfix
git checkout -b hotfix/fix-bug main
git checkout main
git merge --no-ff hotfix/fix-bug
git tag -a v1.0.1
git checkout develop
git merge --no-ff hotfix/fix-bug
git branch -d hotfix/fix-bug
Fork and Pull Request Workflow
# 1. Fork repository on GitHub
# 2. Clone your fork
git clone https://github.com/your-username/repo.git
cd repo
# 3. Add upstream remote
git remote add upstream https://github.com/original-owner/repo.git
git remote -v
# 4. Create feature branch
git checkout -b feature/my-feature
# 5. Make changes and commit
git add .
git commit -m "Add new feature"
# 6. Keep your fork updated
git fetch upstream
git checkout main
git merge upstream/main
git push origin main
# 7. Rebase your feature branch (optional but recommended)
git checkout feature/my-feature
git rebase main
# 8. Push to your fork
git push origin feature/my-feature
# 9. Create Pull Request on GitHub
# - Navigate to original repository
# - Click "New Pull Request"
# - Select your fork and branch
# 10. After PR is merged, update and cleanup
git checkout main
git pull upstream main
git push origin main
git branch -d feature/my-feature
git push origin --delete feature/my-feature
Trunk-Based Development
# Work directly on main branch with short-lived feature branches
# 1. Create short-lived feature branch
git checkout -b feature/quick-fix
# 2. Make small, incremental changes
git add .
git commit -m "Implement part 1 of feature"
# 3. Keep branch up to date with main (multiple times per day)
git checkout main
git pull origin main
git checkout feature/quick-fix
git rebase main
# 4. Merge back to main quickly (within hours or 1-2 days)
git checkout main
git merge --no-ff feature/quick-fix
git push origin main
# 5. Delete feature branch
git branch -d feature/quick-fix
# Alternative: Direct commits to main (for very small changes)
git checkout main
git pull origin main
# Make small change
git add .
git commit -m "Fix typo"
git push origin main
Release Branch Workflow
# Create release branch from main
git checkout -b release/v2.0.0 main
# Make release-specific changes (version bumps, changelog, etc.)
git add .
git commit -m "Prepare release v2.0.0"
# Test the release branch thoroughly
# Fix any bugs found
git add .
git commit -m "Fix release bug"
# Merge to main and tag
git checkout main
git merge --no-ff release/v2.0.0
git tag -a v2.0.0 -m "Release version 2.0.0"
git push origin main
git push origin v2.0.0
# Merge release changes back to develop (if using Gitflow)
git checkout develop
git merge --no-ff release/v2.0.0
# Delete release branch
git branch -d release/v2.0.0
Daily Workflow Patterns
Start of Day
# Update your local repository
git checkout main
git pull origin main
# Check what you were working on
git status
git log --oneline -5
# Resume work on feature branch
git checkout feature/my-feature
git rebase main
During Development
# Check status frequently
git status
# View changes before staging
git diff
# Stage changes selectively
git add -p # Interactive staging
# Commit with meaningful message
git commit -m "feat: Add user authentication
Implement JWT-based authentication system with:
- Login endpoint
- Token validation middleware
- Logout functionality
Refs #123"
# Push to remote frequently
git push origin feature/my-feature
# Save work in progress without committing
git stash push -m "WIP: working on login form"
Before Creating Pull Request
# Make sure branch is up to date
git checkout main
git pull origin main
git checkout feature/my-feature
git rebase main
# Clean up commit history (if needed)
git rebase -i HEAD~5
# Squash, reword, or reorder commits
# Run tests
npm test # or your test command
# Push updated branch
git push --force-with-lease origin feature/my-feature
# Create Pull Request on GitHub
After Pull Request Review
# Address review comments
git add .
git commit -m "Address PR feedback"
# Or amend last commit
git add .
git commit --amend --no-edit
# Force push (your PR branch)
git push --force-with-lease origin feature/my-feature
End of Day
# Commit work in progress
git add .
git commit -m "WIP: partial implementation"
# Or stash if not ready to commit
git stash push -m "WIP: end of day $(date)"
# Push to remote as backup
git push origin feature/my-feature
Working with Multiple Features
# Save current work
git stash
# Switch to different feature
git checkout feature/other-feature
# Work on it...
git add .
git commit -m "Update feature"
# Switch back to original feature
git checkout feature/my-feature
git stash pop
Common Workflow Scenarios
Fixing a Bug in Production
# 1. Create hotfix branch from main
git checkout main
git pull origin main
git checkout -b hotfix/critical-bug
# 2. Fix the bug
git add .
git commit -m "fix: Resolve critical authentication bug
Fix issue where users couldn't login after password reset.
Fixes #456"
# 3. Test thoroughly
npm test
# 4. Merge to main
git checkout main
git merge --no-ff hotfix/critical-bug
git tag -a v1.0.1 -m "Hotfix release 1.0.1"
# 5. Push to production
git push origin main
git push origin v1.0.1
# 6. Merge back to develop
git checkout develop
git merge --no-ff hotfix/critical-bug
# 7. Cleanup
git branch -d hotfix/critical-bug
Syncing Fork with Upstream
# Add upstream if not already added
git remote add upstream https://github.com/original/repo.git
# Fetch upstream changes
git fetch upstream
# Merge upstream changes to main
git checkout main
git merge upstream/main
# Push to your fork
git push origin main
# Update your feature branch
git checkout feature/my-feature
git rebase main
Collaborating on a Branch
# Person A creates branch and pushes
git checkout -b feature/shared-feature
git add .
git commit -m "Initial implementation"
git push -u origin feature/shared-feature
# Person B clones and contributes
git fetch origin
git checkout feature/shared-feature
git add .
git commit -m "Add tests"
git push origin feature/shared-feature
# Person A pulls updates
git checkout feature/shared-feature
git pull origin feature/shared-feature
Recovering from Mistakes
# Undo last commit but keep changes
git reset --soft HEAD~1
# Discard all local changes
git reset --hard HEAD
# Recover deleted branch
git reflog
git checkout -b recovered-branch <commit-hash>
# Undo force push (if reflog available)
git reflog
git reset --hard HEAD@{n}
git push --force-with-lease
# Revert a merged PR
git revert -m 1 <merge-commit-hash>
git push origin main
Working with Large Files
# Install Git LFS
git lfs install
# Track large files
git lfs track "*.psd"
git lfs track "*.mp4"
git lfs track "datasets/*"
# Add .gitattributes
git add .gitattributes
# Add and commit large files
git add large-file.psd
git commit -m "Add design file"
git push origin main
Maintaining Clean History
# Squash commits before merging
git checkout feature/my-feature
git rebase -i main
# In editor, change "pick" to "squash" for commits to combine
# Rewrite commit message
git commit --amend
# Force push (only on feature branches!)
git push --force-with-lease origin feature/my-feature
Best Practices
Commit Messages
# Good commit message structure:
# <type>: <subject>
#
# <body>
#
# <footer>
# Example:
git commit -m "feat: Add user authentication
Implement JWT-based authentication system with login and logout endpoints.
Uses bcrypt for password hashing.
Closes #123"
# Common types:
# feat: New feature
# fix: Bug fix
# docs: Documentation changes
# style: Formatting, missing semicolons, etc.
# refactor: Code restructuring
# test: Adding tests
# chore: Maintenance tasks
General Best Practices
- Commit Often: Make small, logical commits
- Write Clear Messages: Explain what and why
- Use Branches: Keep main stable
- Pull Before Push: Stay synchronized
- Review Before Commit: Check what you're committing
- Don't Commit Secrets: Use .gitignore for sensitive files
- Keep History Clean: Use rebase for feature branches
- Tag Releases: Mark important versions
- Backup Remote: Always have a remote backup
- Learn to Revert: Know how to undo mistakes
.gitignore
# Create .gitignore file
cat > .gitignore << 'EOL'
# Dependencies
node_modules/
vendor/
# Environment files
.env
.env.local
# Build outputs
dist/
build/
*.log
# IDE files
.vscode/
.idea/
*.swp
# OS files
.DS_Store
Thumbs.db
# Compiled files
*.pyc
*.class
*.o
EOL
# Global gitignore
git config --global core.excludesfile ~/.gitignore_global
Troubleshooting
Common Issues
# Undo last commit but keep changes
git reset --soft HEAD~1
# Fix wrong commit message
git commit --amend -m "Correct message"
# Recover deleted branch
git reflog
git checkout -b recovered-branch commit-hash
# Resolve "detached HEAD"
git checkout main
# Remove file from Git but keep locally
git rm --cached file.txt
# Update .gitignore for already tracked files
git rm -r --cached .
git add .
git commit -m "Update .gitignore"
# Find commit that introduced bug
git bisect start
git bisect bad
git bisect good commit-hash
Performance
# Clean up repository
git gc
# Aggressive cleanup
git gc --aggressive
# Prune unreachable objects
git prune
# Show repository size
git count-objects -vH
Git Aliases
# Create useful aliases
git config --global alias.co checkout
git config --global alias.br branch
git config --global alias.ci commit
git config --global alias.st status
git config --global alias.unstage 'reset HEAD --'
git config --global alias.last 'log -1 HEAD'
git config --global alias.visual 'log --graph --oneline --all'
git config --global alias.amend 'commit --amend --no-edit'
# Use aliases
git co main
git ci -m "Message"
git visual
Git Internals
Want to understand how Git works under the hood? The Git Internals guide provides an in-depth exploration of:
- Object Model: Blobs, trees, commits, tags, and SHA-1 hashing
- File Tracking: The index, staging area, and file states
- Refs and HEAD: References, symbolic refs, and detached HEAD
- Plumbing Commands: Low-level commands that power Git
- Pack Files: Storage optimization and delta compression
- Reflog: Recovery and time-travel debugging
- Remote Tracking: How fetch, pull, and push work internally
Understanding internals helps you debug issues, recover from mistakes, and master advanced Git operations.
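As a small taste of that plumbing layer, the sketch below (repository name illustrative) writes a blob with `git hash-object` and reads it back with `git cat-file`:

```shell
git init object-demo && cd object-demo
hash=$(echo "hello" | git hash-object -w --stdin)  # -w writes the blob to .git/objects
git cat-file -t "$hash"                            # object type: blob
git cat-file -p "$hash"                            # contents: hello
ls .git/objects                                    # stored under the hash's first two characters
```

Commits and trees are stored the same way; once a commit exists, `git cat-file -p HEAD` pretty-prints the current commit object.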
Integration with GitHub
GitHub adds collaboration features on top of Git. See the dedicated GitHub guide for:
- Pull requests
- Issues
- Actions (CI/CD)
- Pages
- Wikis
- Organizations and teams
Available Resources
- Git Cheat Sheet - Quick reference guide
- Git Commands - Comprehensive command list
- Git Internals - Deep dive into Git's internal architecture, plumbing commands, refs, object model, and tracking
- GitHub Guide - GitHub-specific features
Learning Resources
Documentation
- Official Git Documentation
- Pro Git Book (free online)
- Git Reference
Quick Reference
Daily Commands
git status # Check status
git add . # Stage all changes
git commit -m "message" # Commit changes
git pull # Update from remote
git push # Push to remote
git log --oneline # View history
Branching
git branch # List branches
git checkout -b feature # Create and switch
git merge feature # Merge branch
git branch -d feature # Delete branch
Undoing
git reset --soft HEAD~1 # Undo commit, keep changes
git restore file.txt # Discard file changes
git revert commit-hash # Create revert commit
git stash # Save temporary changes
Remote
git remote -v # List remotes
git fetch origin # Download changes
git pull origin main # Fetch and merge
git push origin main # Upload changes
Next Steps
- Practice basic commands: add, commit, push, pull
- Learn branching and merging
- Master undoing changes safely
- Explore advanced features: rebase, cherry-pick, bisect
- Set up GitHub account and create repositories
- Contribute to open source projects
- Learn Git workflows (Feature Branch, Gitflow)
- Configure useful aliases and tools
Remember: Git has a learning curve, but it's worth the investment. Start with the basics and gradually explore advanced features as needed.
Git Cheatsheet
Quick reference for the most commonly used Git commands.
Setup and Configuration
# Initial setup
git config --global user.name "Your Name"
git config --global user.email "your.email@example.com"
git config --global core.editor "vim"
git config --global init.defaultBranch main
# View configuration
git config --list
git config user.name
Repository Creation
# Create new repository
git init
# Clone existing repository
git clone <url>
git clone <url> <directory>
git clone -b <branch> <url>
Daily Workflow
# Check status
git status
git status -s # Short format
# Add files
git add <file>
git add . # Add all
git add -p # Interactive staging
# Commit changes
git commit -m "Message"
git commit -am "Message" # Stage and commit
git commit --amend # Modify last commit
# View history
git log
git log --oneline
git log --graph --oneline --all
# Push/Pull
git pull
git pull --rebase
git push
git push -u origin <branch>
Branching
# List branches
git branch # Local branches
git branch -a # All branches
git branch -r # Remote branches
# Create branch
git branch <name>
git checkout -b <name> # Create and switch
git switch -c <name> # Modern syntax
# Switch branches
git checkout <name>
git switch <name>
git checkout - # Previous branch
# Delete branch
git branch -d <name> # Safe delete
git branch -D <name> # Force delete
git push origin --delete <name> # Delete remote
# Rename branch
git branch -m <new_name>
Merging and Rebasing
# Merge
git merge <branch>
git merge --no-ff <branch>
git merge --squash <branch>
git merge --abort
# Rebase
git rebase <branch>
git rebase -i HEAD~3 # Interactive rebase
git rebase --continue
git rebase --abort
# Resolve conflicts
git status # Check conflicts
# Edit files to resolve
git add <resolved_file>
git commit # or git rebase --continue
Remote Operations
# List remotes
git remote -v
# Add/Remove remote
git remote add origin <url>
git remote remove <name>
git remote set-url origin <new_url>
# Fetch
git fetch
git fetch origin
git fetch --all
git fetch -p # Prune deleted branches
# Pull
git pull
git pull origin <branch>
git pull --rebase
# Push
git push
git push origin <branch>
git push -u origin <branch>
git push --tags
git push --force-with-lease # Safer force push
Viewing Changes
# Show changes
git diff # Unstaged changes
git diff --staged # Staged changes
git diff <branch1> <branch2>
git diff HEAD~1
# Show commits
git log
git log --oneline
git log -p # With patches
git log --stat # With statistics
git log --author="Name"
git log --since="2 weeks ago"
git log --grep="pattern"
# Show commit details
git show <commit>
git show <commit>:<file>
# Blame
git blame <file>
Undoing Changes
# Discard changes in working directory
git restore <file>
git restore .
git checkout -- <file> # Old syntax
# Unstage files
git restore --staged <file>
git reset HEAD <file> # Old syntax
# Undo commits
git reset --soft HEAD~1 # Keep changes staged
git reset HEAD~1 # Keep changes unstaged
git reset --hard HEAD~1 # Discard changes
git revert <commit> # Create reverting commit
# Clean untracked files
git clean -n # Dry run
git clean -f # Remove files
git clean -fd # Remove files and directories
Stashing
# Save changes temporarily
git stash
git stash push -m "Message"
git stash -u # Include untracked
# List stashes
git stash list
# Apply stash
git stash apply
git stash apply stash@{2}
git stash pop # Apply and remove
# Manage stashes
git stash show -p
git stash drop stash@{0}
git stash clear
Tags
# List tags
git tag
git tag -l "v1.*"
# Create tags
git tag <name> # Lightweight
git tag -a <name> -m "Message" # Annotated
# Push tags
git push origin <tag>
git push --tags
# Delete tags
git tag -d <name>
git push origin --delete <name>
# Checkout tag
git checkout <tag>
Advanced Commands
# Cherry-pick
git cherry-pick <commit>
git cherry-pick <commit1> <commit2>
# Bisect (find bug)
git bisect start
git bisect bad
git bisect good <commit>
# Test and mark good/bad
git bisect reset
# Reflog (recover lost commits)
git reflog
git checkout <commit>
# Archive
git archive --format=zip HEAD > archive.zip
# Search
git grep "pattern"
git grep -n "pattern" # With line numbers
git log -S "code" # Commits adding/removing "code"
Collaboration Workflows
Feature Branch Workflow
# Start new feature
git checkout -b feature/<name>
# Work on feature
git add .
git commit -m "Add feature"
# Push feature branch
git push -u origin feature/<name>
# After PR is merged
git checkout main
git pull
git branch -d feature/<name>
Sync with Upstream
# Fork workflow
git remote add upstream <original_repo_url>
git fetch upstream
git checkout main
git merge upstream/main
git push origin main
Hotfix Workflow
# Create hotfix from main
git checkout main
git checkout -b hotfix/<issue>
# Fix and commit
git add .
git commit -m "Fix issue"
# Merge to main and develop
git checkout main
git merge --no-ff hotfix/<issue>
git tag -a v1.0.1 -m "Version 1.0.1"
git checkout develop
git merge --no-ff hotfix/<issue>
# Cleanup
git branch -d hotfix/<issue>
Common Scenarios
Forgot to create branch
git stash
git checkout -b feature/<name>
git stash pop
Undo last commit but keep changes
git reset --soft HEAD~1
Amend commit message
git commit --amend -m "New message"
Remove file from Git but keep locally
git rm --cached <file>
Sync fork with original
git fetch upstream
git checkout main
git merge upstream/main
git push origin main
Squash last N commits
git rebase -i HEAD~N
# Change "pick" to "squash" for commits to combine
Change author of last commit
git commit --amend --author="Name <email>"
Create orphan branch
git switch --orphan <branch>
git commit --allow-empty -m "Initial commit"
git push -u origin <branch>
Configuration Aliases
# Create useful aliases
git config --global alias.co checkout
git config --global alias.br branch
git config --global alias.ci commit
git config --global alias.st status
git config --global alias.unstage 'reset HEAD --'
git config --global alias.last 'log -1 HEAD'
git config --global alias.lg 'log --graph --oneline --all'
git config --global alias.amend 'commit --amend --no-edit'
.gitignore Patterns
# Ignore files
*.log
*.tmp
.env
# Ignore directories
node_modules/
dist/
build/
# Ignore with exceptions
*.a
!lib.a
# IDE files
.vscode/
.idea/
*.swp
# OS files
.DS_Store
Thumbs.db
Common Options
# Flags used with multiple commands
-a, --all # All
-b, --branch # Branch
-d, --delete # Delete
-f, --force # Force
-m, --message # Message
-n, --dry-run # Dry run
-p, --patch # Interactive patch mode
-u, --set-upstream # Set upstream
-v, --verbose # Verbose output
# Common patterns
HEAD # Current commit
HEAD~1 # Previous commit
HEAD~n # N commits ago
HEAD^ # First parent (same as HEAD~1)
<commit> # Commit hash
<branch> # Branch name
origin # Default remote name
main/master # Default branch names
Emergency Commands
# Abort everything
git merge --abort
git rebase --abort
git cherry-pick --abort
# Recover lost work
git reflog
git checkout <lost_commit>
git branch recover-branch <lost_commit>
# Undo force push (if possible)
git reflog
git reset --hard <previous_commit>
git push --force-with-lease
# Remove sensitive data from history
# (filter-branch is deprecated; the git-filter-repo tool is the recommended replacement)
git filter-branch --force --index-filter \
"git rm --cached --ignore-unmatch <file>" \
--prune-empty --tag-name-filter cat -- --all
Quick Reference: Git States
Working Directory → Staging Area → Repository → Remote
(edit) (add) (commit) (push)
Quick Reference: Undoing
Working Directory: git restore <file>
Staging Area: git restore --staged <file>
Last Commit: git commit --amend
Previous Commits: git revert <commit>
Local Branch: git reset --hard <commit>
Quick Reference: Branch Management
Create: git checkout -b <name>
Switch: git switch <name>
Merge: git merge <name>
Delete: git branch -d <name>
Remote: git push -u origin <name>
Git Commands Reference
Comprehensive reference of Git commands organized by category.
Repository Setup
Initialize Repository
# Create new Git repository
git init
# Initialize with specific branch name
git init -b main
# Create bare repository (for remote)
git init --bare
Clone Repository
# Clone repository
git clone <repository_url>
# Clone to specific directory
git clone <repository_url> <directory>
# Clone specific branch
git clone -b <branch> <repository_url>
# Shallow clone (limited history)
git clone --depth 1 <repository_url>
# Clone with submodules
git clone --recursive <repository_url>
Configuration
User Settings
# Set user name
git config --global user.name "Your Name"
# Set user email
git config --global user.email "your.email@example.com"
# View user name
git config user.name
# View user email
git config user.email
Repository Settings
# Set local config (repository-specific)
git config user.name "Your Name"
# Set editor
git config --global core.editor "vim"
# Set default branch name
git config --global init.defaultBranch main
# Enable color output
git config --global color.ui auto
# Set merge strategy
git config --global pull.rebase false
# View all config
git config --list
# View specific config
git config <key>
# Edit config file
git config --global --edit
Aliases
# Create alias
git config --global alias.co checkout
git config --global alias.br branch
git config --global alias.ci commit
git config --global alias.st status
git config --global alias.last 'log -1 HEAD'
git config --global alias.visual 'log --graph --oneline --all'
Basic Operations
Status and Information
# Show working tree status
git status
# Short status
git status -s
# Show status with branch info
git status -sb
# List all tracked files
git ls-files
# Show repository info
git remote show origin
Add Files
# Add specific file
git add <file>
# Add all files
git add .
# Add all files in directory
git add <directory>/
# Add by pattern
git add *.js
# Add interactively
git add -i
# Add patch (selective staging)
git add -p
# Add all (including deleted)
git add -A
# Add modified and deleted (not new)
git add -u
Commit Changes
# Commit staged changes
git commit -m "Commit message"
# Commit with detailed message (opens editor)
git commit
# Stage all tracked files and commit
git commit -am "Commit message"
# Amend last commit
git commit --amend
# Amend without changing message
git commit --amend --no-edit
# Amend and change author
git commit --amend --author="Name <email>"
# Empty commit (no changes)
git commit --allow-empty -m "Empty commit"
# Commit with specific date
git commit --date="2024-01-01" -m "Message"
Remove and Move Files
# Remove file from working directory and staging
git rm <file>
# Stop tracking file (remove from index, keep in working directory)
git rm --cached <file>
# Remove directory
git rm -r <directory>
# Move/rename file
git mv <old_name> <new_name>
Viewing History
Log
# View commit history
git log
# Compact one-line log
git log --oneline
# Graph view
git log --graph --oneline --all
# Decorate with branch/tag names
git log --decorate
# Pretty format
git log --pretty=format:"%h - %an, %ar : %s"
# Limit number of commits
git log -n 5
git log -5
# Show commits by author
git log --author="John"
# Show commits in date range
git log --since="2 weeks ago"
git log --after="2024-01-01"
git log --until="2024-12-31"
git log --before="2024-12-31"
# Show file statistics
git log --stat
# Show detailed patch
git log -p
# Show commits affecting specific file
git log -- <file>
# Search commit messages
git log --grep="fix bug"
# Show commits that added/removed specific text
git log -S "function_name"
# Show commits by committer (not author)
git log --committer="John"
# Show merge commits only
git log --merges
# Show non-merge commits
git log --no-merges
# Show first parent only
git log --first-parent
Show Commit Details
# Show commit details
git show <commit>
# Show specific file at commit
git show <commit>:<file>
# Show commit statistics
git show --stat <commit>
# Show commit names only
git show --name-only <commit>
# Show commit with word diff
git show --word-diff <commit>
Diff
# Show unstaged changes
git diff
# Show staged changes
git diff --staged
git diff --cached
# Show changes in specific file
git diff <file>
# Compare branches
git diff <branch1>..<branch2>
# Compare commits
git diff <commit1> <commit2>
# Compare with specific commit
git diff HEAD~1
# Word-level diff
git diff --word-diff
# Show statistics only
git diff --stat
# Show file names only
git diff --name-only
# Show file names with status
git diff --name-status
# Ignore whitespace
git diff -w
git diff --ignore-all-space
Blame
# Show who changed each line
git blame <file>
# Show blame for specific lines
git blame -L 10,20 <file>
# Show blame with email
git blame -e <file>
# Ignore whitespace changes
git blame -w <file>
Reflog
# Show reference log (command history)
git reflog
# Show reflog for specific branch
git reflog <branch>
# Show reflog with dates
git reflog --date=relative
# Expire old reflog entries
git reflog expire --expire=30.days.ago --all
Branching
Create and Switch Branches
# List branches
git branch
# List all branches (including remote)
git branch -a
# List remote branches
git branch -r
# Create new branch
git branch <branch_name>
# Create branch from specific commit
git branch <branch_name> <commit>
# Switch to branch
git checkout <branch_name>
# Create and switch to new branch
git checkout -b <branch_name>
# Create branch from specific commit and switch
git checkout -b <branch_name> <commit>
# Modern syntax (Git 2.23+)
git switch <branch_name>
git switch -c <branch_name>
# Switch to previous branch
git checkout -
git switch -
Delete Branches
# Delete local branch (safe)
git branch -d <branch_name>
# Force delete local branch
git branch -D <branch_name>
# Delete remote branch
git push origin --delete <branch_name>
git push origin :<branch_name>
Rename Branches
# Rename current branch
git branch -m <new_name>
# Rename specific branch
git branch -m <old_name> <new_name>
# Rename and push to remote
git branch -m <new_name>
git push origin -u <new_name>
git push origin --delete <old_name>
Branch Information
# Show branches with last commit
git branch -v
# Show merged branches
git branch --merged
# Show unmerged branches
git branch --no-merged
# Show branches containing commit
git branch --contains <commit>
# Track remote branch
git branch --set-upstream-to=origin/<branch>
git branch -u origin/<branch>
Merging
Basic Merge
# Merge branch into current branch
git merge <branch>
# Merge with commit message
git merge <branch> -m "Merge message"
# Merge without fast-forward
git merge --no-ff <branch>
# Merge with fast-forward only
git merge --ff-only <branch>
# Squash merge (combine all commits)
git merge --squash <branch>
# Abort merge
git merge --abort
# Continue merge after resolving conflicts
git merge --continue
Merge Strategies
# Use recursive strategy (default before Git 2.34; newer Git defaults to ort)
git merge -s recursive <branch>
# Use ours strategy (keep our version)
git merge -s ours <branch>
# Use theirs strategy
git merge -X theirs <branch>
# Ignore whitespace during merge
git merge -X ignore-all-space <branch>
Rebasing
Basic Rebase
# Rebase current branch onto another
git rebase <branch>
# Rebase onto specific commit
git rebase <commit>
# Continue rebase after resolving conflicts
git rebase --continue
# Skip current commit
git rebase --skip
# Abort rebase
git rebase --abort
# Rebase and preserve merges (-p is deprecated; use --rebase-merges)
git rebase --rebase-merges <branch>
Interactive Rebase
# Interactive rebase last N commits
git rebase -i HEAD~3
# Interactive rebase from specific commit
git rebase -i <commit>
# Interactive rebase with autosquash
git rebase -i --autosquash <branch>
# Commands in interactive rebase:
# pick (p) = use commit
# reword (r) = use commit, but edit message
# edit (e) = use commit, but stop for amending
# squash (s) = merge with previous commit
# fixup (f) = like squash, but discard message
# drop (d) = remove commit
# exec (x) = run shell command
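The fixup entries pair with commits created as "fixup! <subject>", which --autosquash reorders automatically. A minimal sketch in a throwaway repository (assumes git on PATH and Git >= 2.28 for init -b; setting GIT_SEQUENCE_EDITOR=true accepts the generated todo list unedited, so the rebase runs non-interactively):

```shell
#!/bin/sh
set -e
repo=$(mktemp -d) && cd "$repo"
git init -q -b main
git config user.email "you@example.com" && git config user.name "You"

echo a > file1.txt && git add file1.txt && git commit -qm "Add file1"
echo b > file2.txt && git add file2.txt && git commit -qm "Add file2"

# A follow-up fix for the first commit, recorded as a fixup commit
echo fix >> file1.txt
git commit -qam "fixup! Add file1"

# --autosquash pre-arranges the todo list; the dummy editor accepts it,
# so the fixup is folded into "Add file1" without any prompts
GIT_SEQUENCE_EDITOR=true git rebase -i --autosquash --root

git log --oneline   # two commits remain; the fixup was absorbed
```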
Remote Operations
Remote Management
# List remotes
git remote
# List remotes with URLs
git remote -v
# Add remote
git remote add <name> <url>
# Remove remote
git remote remove <name>
git remote rm <name>
# Rename remote
git remote rename <old> <new>
# Change remote URL
git remote set-url <name> <new_url>
# Show remote details
git remote show <name>
# Prune stale remote branches
git remote prune origin
Fetch
# Fetch from remote
git fetch
# Fetch from specific remote
git fetch <remote>
# Fetch specific branch
git fetch <remote> <branch>
# Fetch all remotes
git fetch --all
# Fetch and prune deleted remote branches
git fetch -p
git fetch --prune
# Fetch tags
git fetch --tags
# Dry run (show what would be fetched)
git fetch --dry-run
Pull
# Pull from tracked remote branch
git pull
# Pull from specific remote and branch
git pull <remote> <branch>
# Pull with rebase
git pull --rebase
# Pull with fast-forward only
git pull --ff-only
# Pull all submodules
git pull --recurse-submodules
# Pull and prune
git pull -p
Push
# Push to remote
git push
# Push to specific remote and branch
git push <remote> <branch>
# Push and set upstream
git push -u <remote> <branch>
# Push all branches
git push --all
# Push tags
git push --tags
# Push specific tag
git push <remote> <tag>
# Force push (dangerous!)
git push --force
# Safer force push (checks remote state)
git push --force-with-lease
# Delete remote branch
git push <remote> --delete <branch>
# Delete remote tag
git push <remote> --delete <tag>
# Dry run (show what would be pushed)
git push --dry-run
Undoing Changes
Working Directory
# Discard changes in file
git checkout -- <file>
# Discard all changes
git checkout -- .
# Modern syntax
git restore <file>
git restore .
# Restore from specific commit
git restore --source=<commit> <file>
# Clean untracked files
git clean -f
# Clean untracked files and directories
git clean -fd
# Clean ignored files too
git clean -fdx
# Dry run (show what would be removed)
git clean -n
Staging Area
# Unstage file
git reset HEAD <file>
# Unstage all files
git reset HEAD
# Modern syntax
git restore --staged <file>
git restore --staged .
Commits
# Undo last commit, keep changes staged
git reset --soft HEAD~1
# Undo last commit, keep changes unstaged
git reset --mixed HEAD~1
git reset HEAD~1
# Undo last commit, discard changes
git reset --hard HEAD~1
# Reset to specific commit
git reset --hard <commit>
# Create new commit that undoes changes
git revert <commit>
# Revert merge commit
git revert -m 1 <merge_commit>
# Revert without committing
git revert -n <commit>
# Revert range of commits
git revert <commit1>..<commit2>
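The difference between reset and revert shows up in the history shape: revert adds a commit instead of rewriting anything. A throwaway-repo sketch (file names are illustrative):

```shell
#!/bin/sh
set -e
repo=$(mktemp -d) && cd "$repo"
git init -q -b main
git config user.email "you@example.com" && git config user.name "You"

echo one > notes.txt && git add notes.txt && git commit -qm "one"
echo two >> notes.txt && git commit -qam "two"

# Revert the last commit: history grows, but the content goes back
git revert --no-edit HEAD
git rev-list --count HEAD   # 3 commits: one, two, and the revert
cat notes.txt               # one
```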
Stashing
Basic Stash
# Stash current changes
git stash
# Stash with message (save is deprecated; prefer push -m)
git stash save "Work in progress"
git stash push -m "Work in progress"
# Stash including untracked files
git stash -u
git stash --include-untracked
# Stash including ignored files
git stash -a
git stash --all
# Stash specific files
git stash push <file>
# Stash with patch mode
git stash -p
Managing Stashes
# List stashes
git stash list
# Show stash contents
git stash show
git stash show -p
# Show specific stash
git stash show stash@{1}
# Apply last stash
git stash apply
# Apply specific stash
git stash apply stash@{1}
# Apply and remove stash (pop)
git stash pop
# Pop specific stash
git stash pop stash@{1}
# Create branch from stash
git stash branch <branch_name>
# Drop specific stash
git stash drop stash@{1}
# Clear all stashes
git stash clear
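A full stash round trip can be demonstrated in a throwaway repository (file name is illustrative):

```shell
#!/bin/sh
set -e
repo=$(mktemp -d) && cd "$repo"
git init -q -b main
git config user.email "you@example.com" && git config user.name "You"

echo base > work.txt && git add work.txt && git commit -qm "base"

echo wip >> work.txt                 # uncommitted change
git stash push -m "wip on work.txt"  # working tree is clean again
git diff --quiet                     # exits 0: no local changes

git stash pop                        # change restored, stash entry dropped
grep -q wip work.txt
git stash list                       # prints nothing: stash is empty
```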
Tags
Create Tags
# List tags
git tag
# List tags with pattern
git tag -l "v1.*"
# Create lightweight tag
git tag <tag_name>
# Create annotated tag
git tag -a <tag_name> -m "Tag message"
# Tag specific commit
git tag <tag_name> <commit>
# Tag with specific date (git tag has no --date flag; set the committer date)
GIT_COMMITTER_DATE="2024-01-01" git tag -a <tag_name> -m "Message"
Manage Tags
# Show tag details
git show <tag_name>
# Delete local tag
git tag -d <tag_name>
# Delete remote tag
git push origin --delete <tag_name>
git push origin :refs/tags/<tag_name>
# Push tag to remote
git push origin <tag_name>
# Push all tags
git push --tags
# Fetch tags from remote
git fetch --tags
# Checkout tag (creates detached HEAD)
git checkout <tag_name>
# Create branch from tag
git checkout -b <branch_name> <tag_name>
Advanced Operations
Cherry-Pick
# Apply specific commit
git cherry-pick <commit>
# Apply multiple commits
git cherry-pick <commit1> <commit2>
# Apply commit range
git cherry-pick <commit1>..<commit2>
# Cherry-pick without committing
git cherry-pick -n <commit>
# Continue cherry-pick
git cherry-pick --continue
# Abort cherry-pick
git cherry-pick --abort
# Skip current commit
git cherry-pick --skip
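A minimal cherry-pick round trip in a throwaway repository (branch and file names are illustrative):

```shell
#!/bin/sh
set -e
repo=$(mktemp -d) && cd "$repo"
git init -q -b main
git config user.email "you@example.com" && git config user.name "You"
git commit -qm "base" --allow-empty

git switch -qc feature
echo fix > fix.txt && git add fix.txt && git commit -qm "useful fix"

git switch -q main
git cherry-pick "$(git rev-parse feature)"   # copy the commit onto main
git show -s --format=%s HEAD                 # useful fix
```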
Bisect
# Start bisect
git bisect start
# Mark current commit as bad
git bisect bad
# Mark current commit as good
git bisect good
# Mark specific commit as good
git bisect good <commit>
# Skip current commit
git bisect skip
# Reset bisect
git bisect reset
# Visualize bisect
git bisect visualize
# Run automated bisect
git bisect run <script>
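With a test command, bisect can find the breaking commit entirely on its own. A sketch in a throwaway repository (the "bug" and the grep-based test are illustrative; bisect run treats exit 0 as good and 1-127 as bad):

```shell
#!/bin/sh
set -e
repo=$(mktemp -d) && cd "$repo"
git init -q -b main
git config user.email "you@example.com" && git config user.name "You"

# Five commits; the "bug" appears at commit 4
for i in 1 2 3 4 5; do
  echo "change $i" > work.txt
  if [ "$i" -ge 4 ]; then echo broken > status.txt; else echo ok > status.txt; fi
  git add . && git commit -qm "commit $i"
done

git bisect start HEAD HEAD~4                    # bad = HEAD, good = first commit
git bisect run sh -c '! grep -q broken status.txt'
bad=$(git show -s --format=%s refs/bisect/bad)  # subject of the first bad commit
git bisect reset
echo "first bad: $bad"                          # first bad: commit 4
```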
Submodules
# Add submodule
git submodule add <repository_url> <path>
# Initialize submodules
git submodule init
# Update submodules
git submodule update
# Clone with submodules
git clone --recursive <repository_url>
# Update all submodules
git submodule update --remote
# Remove submodule
git submodule deinit <path>
git rm <path>
# Show submodule status
git submodule status
# Foreach command on all submodules
git submodule foreach <command>
Worktrees
# List worktrees
git worktree list
# Add new worktree
git worktree add <path> <branch>
# Add worktree with new branch
git worktree add -b <new_branch> <path>
# Remove worktree
git worktree remove <path>
# Prune worktree information
git worktree prune
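Worktrees let you check out a second branch in a separate directory without touching the first. A throwaway-repo sketch (branch name and paths are illustrative):

```shell
#!/bin/sh
set -e
repo=$(mktemp -d) && cd "$repo"
git init -q -b main
git config user.email "you@example.com" && git config user.name "You"
git commit -qm "init" --allow-empty

# Add a second checkout on a new branch, in a sibling directory
wt="$(mktemp -d)/hotfix"
git worktree add -b hotfix "$wt"

git worktree list                          # two entries: main checkout + hotfix
git -C "$wt" rev-parse --abbrev-ref HEAD   # hotfix
```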
Archive
# Create archive of repository
git archive --format=zip HEAD > archive.zip
# Archive specific branch
git archive --format=tar <branch> > archive.tar
# Archive with prefix
git archive --prefix=project/ HEAD > archive.tar
# Archive specific directory
git archive HEAD <directory>/ > archive.tar
Maintenance
Repository Maintenance
# Run garbage collection
git gc
# Aggressive garbage collection
git gc --aggressive
# Prune unreachable objects
git prune
# Verify repository integrity
git fsck
# Show repository statistics
git count-objects -v
# Show repository size
git count-objects -vH
Optimization
# Repack repository
git repack
# Aggressive repack
git repack -a -d --depth=250 --window=250
# Prune old reflog entries
git reflog expire --expire=30.days.ago --all
# Remove old objects
git prune --expire=30.days.ago
Searching
Grep
# Search for text in repository
git grep "pattern"
# Search with line numbers
git grep -n "pattern"
# Search for whole word
git grep -w "pattern"
# Search case-insensitively
git grep -i "pattern"
# Search in specific commit
git grep "pattern" <commit>
# Search with context
git grep -C 2 "pattern"
# Show file names only
git grep -l "pattern"
# Count matches per file
git grep -c "pattern"
# Search with AND condition
git grep -e "pattern1" --and -e "pattern2"
# Search with OR condition
git grep -e "pattern1" --or -e "pattern2"
Log Search
# Search commit messages
git log --grep="pattern"
# Search for commits that change occurrences of a string (pickaxe)
git log -S "code"
# Search for commits whose diff matches a regex
git log -G "regex"
# Search author
git log --author="name"
# Search committer
git log --committer="name"
Help
# Show help for command
git help <command>
git <command> --help
# Show quick help
git <command> -h
# Show all commands
git help -a
# Show guides
git help -g
# Show config options
git help config
Git Internals
Overview
Git is often described as a "content-addressable filesystem with a VCS user interface on top." Understanding Git's internal architecture reveals how it efficiently stores data, tracks changes, and enables powerful version control operations. This guide explores the plumbing commands, internal data structures, and core concepts that make Git work.
Why Learn Git Internals?
- Debug complex issues more effectively
- Understand what commands actually do
- Recover from disasters
- Optimize repository performance
- Build custom Git tools and automation
The .git Directory
Every Git repository has a .git directory containing all Git metadata and objects.
Directory Structure
.git/
├── HEAD # Points to current branch
├── config # Repository-specific configuration
├── description # Repository description (for GitWeb)
├── index # Staging area (binary file)
├── hooks/ # Client and server-side hook scripts
├── info/ # Global exclude file and refs
│ └── exclude # gitignore patterns not in .gitignore
├── objects/ # All content: commits, trees, blobs, tags
│ ├── pack/ # Packfiles for efficient storage
│ └── info/ # Object info and packs
├── refs/ # References (branches and tags)
│ ├── heads/ # Local branches
│ ├── remotes/ # Remote-tracking branches
│ └── tags/ # Tags
├── logs/ # Reflog information
│ ├── HEAD # HEAD history
│ └── refs/ # Branch history
└── packed-refs # Packed references for performance
Exploring .git Directory
# Navigate to .git
cd .git
# View HEAD (current branch pointer)
cat HEAD
# Output: ref: refs/heads/main
# View current branch
cat refs/heads/main
# Output: a3f2b1c... (commit SHA-1)
# View remote branch
cat refs/remotes/origin/main
# List all objects
find objects/ -type f
Git Objects: The Building Blocks
Git stores everything as objects identified by SHA-1 hashes. There are four object types:
- Blob - File content
- Tree - Directory structure
- Commit - Snapshot with metadata
- Tag - Annotated tag with metadata
Object Storage
Objects are stored in .git/objects/:
- First 2 characters of SHA-1 = subdirectory
- Remaining 38 characters = filename
- Content is zlib-compressed
# Example: Object a3f2b1c4...
# Stored at: .git/objects/a3/f2b1c4...
1. Blob Objects
Blobs store file content (data only, no filename or metadata).
# Create a blob manually (plumbing)
echo "Hello, Git!" | git hash-object -w --stdin
# Output: 8ab686eafeb1f44702738c8b0f24f2567c36da6d
# -w = write to object database
# --stdin = read from standard input
# View blob content
git cat-file -p 8ab686eafeb1f44702738c8b0f24f2567c36da6d
# Output: Hello, Git!
# Check object type
git cat-file -t 8ab686ea
# Output: blob
# View object size
git cat-file -s 8ab686ea
# Output: 12
Creating blobs from files:
# Create a file
echo "Git internals are fascinating" > test.txt
# Hash and store the file
git hash-object -w test.txt
# Output: 2c8b4e3b7c1a9f...
# The blob is now in .git/objects/
# File content is stored, but filename is NOT
2. Tree Objects
Trees represent directory structure, mapping filenames to blobs and other trees.
# View a tree object
git cat-file -p main^{tree}
# Output:
# 100644 blob a3f2b1c... README.md
# 100644 blob 7e4d3a2... index.js
# 040000 tree 9c1f5b8... src
# Tree entries format:
# <mode> <type> <sha-1> <filename>
File Modes:
- 100644 - Normal file
- 100755 - Executable file
- 120000 - Symbolic link
- 040000 - Directory (tree)
- 160000 - Gitlink (submodule)
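These modes can be observed directly. A throwaway-repo sketch (file names are illustrative; symlink creation assumes a Unix filesystem with core.fileMode support):

```shell
#!/bin/sh
set -e
repo=$(mktemp -d) && cd "$repo"
git init -q -b main
git config user.email "you@example.com" && git config user.name "You"

echo hi > plain.txt
printf '#!/bin/sh\necho hi\n' > run.sh && chmod +x run.sh
ln -s plain.txt link.txt
mkdir sub && echo x > sub/nested.txt

git add . && git commit -qm "modes"
git ls-tree HEAD
# 120000 blob ... link.txt     (symlink)
# 100644 blob ... plain.txt    (normal file)
# 100755 blob ... run.sh       (executable)
# 040000 tree ... sub          (directory)
```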
Creating a tree manually:
# Create a tree from index
git write-tree
# Output: 9c1f5b8a... (tree SHA-1)
# Add files to index first
git update-index --add --cacheinfo 100644 \
a3f2b1c4... README.md
git update-index --add --cacheinfo 100644 \
7e4d3a2b... index.js
# Write tree from current index
git write-tree
Reading tree contents:
# List tree contents recursively
git ls-tree -r -t main^{tree}
# -r = recursive
# -t = show trees as well
# Pretty print tree structure
git ls-tree --abbrev main^{tree}
3. Commit Objects
Commits point to a tree (snapshot) and contain metadata.
# View commit object
git cat-file -p HEAD
# Output:
# tree 9c1f5b8a...
# parent a3f2b1c4...
# author John Doe <john@example.com> 1234567890 -0500
# committer John Doe <john@example.com> 1234567890 -0500
#
# Commit message here
Commit Structure:
- tree - Points to root tree (project snapshot)
- parent - Previous commit(s); merge commits have multiple
- author - Who wrote the code (name, email, timestamp)
- committer - Who committed (may differ from author)
- Commit message
Creating a commit manually:
# Create a commit (plumbing)
echo "Initial commit" | git commit-tree 9c1f5b8a
# Output: b4e3c2d1... (commit SHA-1)
# Create commit with parent
echo "Second commit" | git commit-tree 7a2b3c4d -p b4e3c2d1
# Update branch to point to new commit
git update-ref refs/heads/main b4e3c2d1
4. Tag Objects
Annotated tags are objects containing metadata about a tag.
# Create annotated tag
git tag -a v1.0 -m "Version 1.0"
# View tag object
git cat-file -p v1.0
# Output:
# object a3f2b1c4...
# type commit
# tag v1.0
# tagger John Doe <john@example.com> 1234567890 -0500
#
# Version 1.0
Lightweight vs Annotated Tags:
# Lightweight tag (just a ref)
git tag v1.0-light
cat .git/refs/tags/v1.0-light
# Output: a3f2b1c4... (points directly to commit)
# Annotated tag (object)
git tag -a v1.0 -m "Release"
cat .git/refs/tags/v1.0
# Output: b7e8f3a... (points to tag object)
Content-Addressable Storage
Git uses SHA-1 hashing to create content-addressable storage.
How SHA-1 Works in Git
# Git computes SHA-1 of:
# "blob <size>\0<content>"
# Example calculation (echo appends a newline, so include it in the size)
content="Hello, Git!"
size=$(( ${#content} + 1 ))
{ printf "blob %s\0" "$size"; echo "$content"; } | sha1sum
# Output: 8ab686eafeb1f44702738c8b0f24f2567c36da6d
# This matches what git hash-object produces
echo "Hello, Git!" | git hash-object --stdin
Properties of Content-Addressable Storage
- Deduplication: Identical content = same hash = stored once
- Integrity: SHA-1 acts as checksum; corruption is detectable
- Immutability: Can't change content without changing hash
- Efficient: Easy to check if object exists (hash lookup)
# Example: Two files with identical content
echo "Same content" > file1.txt
echo "Same content" > file2.txt
# Both produce same blob
git hash-object file1.txt # abc123...
git hash-object file2.txt # abc123... (identical!)
# Git stores content only once
File Tracking and the Index
The index (staging area) is a binary file at .git/index that serves as a staging area between the working directory and repository.
The Three States of Files
┌─────────────────┐ git add ┌─────────────────┐ git commit ┌─────────────────┐
│ Working Dir │──────────────→│ Staging Area │───────────────→│ Repository │
│ (modified) │ │ (staged) │ │ (committed) │
└─────────────────┘ └─────────────────┘ └─────────────────┘
        ↑                                                                     │
        └─────────────────────────────── git checkout ────────────────────────┘
File States
- Untracked: Not in index or last commit
- Unmodified: In repository, unchanged
- Modified: Changed since last commit
- Staged: Marked for next commit
# File lifecycle diagram
Untracked ──add──→ Staged ──commit──→ Unmodified
                     ↑                     │
                     │                     │ edit
                     │                     ↓
                     └──────add─────── Modified
Viewing the Index
# View index contents
git ls-files --stage
# Output:
# 100644 a3f2b1c... 0 README.md
# 100644 7e4d3a2... 0 src/index.js
# 100644 9f8e7d6... 0 package.json
# Format: <mode> <sha-1> <stage> <filename>
# stage: 0 = normal, 1-3 = conflict resolution stages
Index Stages (for merge conflicts):
- Stage 0: Normal entry
- Stage 1: Common ancestor version
- Stage 2: "ours" (current branch)
- Stage 3: "theirs" (merging branch)
# During merge conflict
git ls-files --stage
# 100644 a1b2c3... 1 conflicted.txt (base)
# 100644 d4e5f6... 2 conflicted.txt (ours)
# 100644 0a1b2c... 3 conflicted.txt (theirs)
Working with the Index
# Add file to index
git update-index --add --cacheinfo 100644 a3f2b1c README.md
# Remove from index (keep in working dir)
git update-index --force-remove README.md
# Refresh index (update stat info)
git update-index --refresh
# Show index and working tree differences
git diff-files
# Shows files modified in working dir
# Show index and repository differences
git diff-index --cached HEAD
# Shows staged changes
How git add Works Internally
# When you run: git add file.txt
# 1. Git computes SHA-1 of file content
hash=$(git hash-object -w file.txt)
# 2. Stores blob in .git/objects/
# (Already done by -w flag above)
# 3. Updates index with new hash
git update-index --add --cacheinfo 100644 $hash file.txt
# This is what git add does behind the scenes!
Refs: Pointers to Commits
References (refs) are human-readable names that point to commits. They're stored in .git/refs/.
Types of Refs
- Heads (branches): .git/refs/heads/
- Tags: .git/refs/tags/
- Remotes: .git/refs/remotes/
# View a ref (just a file with commit SHA-1)
cat .git/refs/heads/main
# Output: a3f2b1c4d5e6f708192a3b4c5d6e7f8091a2b3c4
# All refs are just text files!
HEAD: The Current Reference
HEAD is a symbolic reference pointing to the current branch.
# View HEAD
cat .git/HEAD
# Output: ref: refs/heads/main
# HEAD points to a branch, which points to a commit
cat .git/refs/heads/main
# Output: a3f2b1c...
Normal HEAD (attached):
HEAD → refs/heads/main → commit a3f2b1c
Detached HEAD:
HEAD → commit a3f2b1c (no branch)
Detached HEAD State
# Checkout specific commit
git checkout a3f2b1c
# Warning: You are in 'detached HEAD' state
# HEAD now points directly to commit
cat .git/HEAD
# Output: a3f2b1c... (no longer "ref: refs/heads/...")
# Any commits made here are "orphaned" unless you create a branch
git switch -c new-branch # Attach HEAD to new branch
Symbolic References
# HEAD is a symbolic ref
git symbolic-ref HEAD
# Output: refs/heads/main
# Change HEAD to point to different branch
git symbolic-ref HEAD refs/heads/develop
# Now on develop branch (without checking out files)
# Read the ref HEAD points to
git symbolic-ref HEAD
git rev-parse HEAD # Get the commit SHA-1
Special Refs
- HEAD: Current commit/branch
- ORIG_HEAD: Previous HEAD (before risky operations)
- FETCH_HEAD: Last fetched branch
- MERGE_HEAD: Commit being merged
- CHERRY_PICK_HEAD: Commit being cherry-picked
# ORIG_HEAD is set by commands that move HEAD
git reset --hard HEAD~1 # ORIG_HEAD now points to previous HEAD
# Undo the reset
git reset --hard ORIG_HEAD
# MERGE_HEAD during merge
git merge feature-branch
# .git/MERGE_HEAD exists during merge conflict
cat .git/MERGE_HEAD # Shows commit being merged
Creating and Managing Refs
# Create a branch (low-level)
git update-ref refs/heads/new-branch a3f2b1c
# This is what git branch does!
# Equivalent to:
echo "a3f2b1c..." > .git/refs/heads/new-branch
# Delete a ref
git update-ref -d refs/heads/old-branch
# List all refs
git for-each-ref
# Output:
# a3f2b1c... commit refs/heads/main
# b4e3c2d... commit refs/heads/feature
# 7a2b3c4... commit refs/remotes/origin/main
# 9f8e7d6... tag refs/tags/v1.0
# Format output
git for-each-ref --format='%(refname:short) %(objecttype) %(objectname:short)'
Packed References
For performance, Git can pack refs into .git/packed-refs.
# View packed refs
cat .git/packed-refs
# Output:
# # pack-refs with: peeled fully-peeled sorted
# a3f2b1c... refs/heads/main
# b4e3c2d... refs/remotes/origin/main
# 7a2b3c4... refs/tags/v1.0
# ^9f8e7d6... (peeled tag - points to commit, not tag object)
# Pack refs manually
git pack-refs --all --prune
# Loose refs take precedence over packed refs
Plumbing vs Porcelain Commands
Git commands are divided into two categories:
- Porcelain: High-level, user-friendly commands (git commit, git push)
- Plumbing: Low-level commands that manipulate Git internals
Why Plumbing Commands?
- Automation: Build scripts and tools
- Understanding: Learn how Git works
- Recovery: Fix broken repositories
- Debugging: Investigate issues
Essential Plumbing Commands
Object Inspection
# cat-file: View object content
git cat-file -t a3f2b1c # Type (blob, tree, commit, tag)
git cat-file -s a3f2b1c # Size
git cat-file -p a3f2b1c # Pretty-print content
git cat-file blob a3f2b1c # View blob content
# rev-parse: Parse revisions
git rev-parse HEAD # Full SHA-1 of HEAD
git rev-parse --short HEAD # Short SHA-1
git rev-parse main # Resolve branch to commit
git rev-parse HEAD~3 # Three commits before HEAD
# ls-tree: List tree contents
git ls-tree HEAD # Root tree
git ls-tree -r HEAD # Recursive
git ls-tree HEAD src/ # Specific directory
Object Creation
# hash-object: Create blob
echo "content" | git hash-object -w --stdin
git hash-object -w file.txt
# mktree: Create tree from stdin
# Format: <mode> SP <type> SP <sha1> TAB <filename>
git mktree << EOF
100644 blob a3f2b1c... file1.txt
100644 blob b4e3c2d... file2.txt
040000 tree 7a2b3c4... subdir
EOF
# commit-tree: Create commit
echo "Commit message" | git commit-tree 9c1f5b8 -p a3f2b1c
# write-tree: Create tree from index
git write-tree
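The mktree example above uses elided hashes; a runnable version, building a tree by hand from a single blob (the file name greeting.txt is illustrative; note mktree wants a TAB before the filename):

```shell
#!/bin/sh
set -e
repo=$(mktemp -d) && cd "$repo"
git init -q -b main

# Store a blob, then wrap it in a hand-built tree object
blob=$(echo "hello" | git hash-object -w --stdin)
tree=$(printf '100644 blob %s\tgreeting.txt\n' "$blob" | git mktree)

git cat-file -t "$tree"   # tree
git cat-file -p "$tree"   # 100644 blob <sha>    greeting.txt
```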
Reference Management
# update-ref: Create/update references
git update-ref refs/heads/test a3f2b1c
git update-ref -d refs/heads/test # Delete
# symbolic-ref: Manage symbolic refs
git symbolic-ref HEAD refs/heads/main
# for-each-ref: Iterate over refs
git for-each-ref refs/heads/
git for-each-ref --format='%(refname)' refs/tags/
Index Manipulation
# update-index: Modify index
git update-index --add --cacheinfo 100644 a3f2b1c file.txt
git update-index --remove file.txt
git update-index --refresh
# ls-files: Show index contents
git ls-files # All tracked files
git ls-files --stage # With hash and mode
git ls-files --deleted # Deleted in working dir
git ls-files --modified # Modified in working dir
git ls-files --others # Untracked files
# read-tree: Read tree into index
git read-tree HEAD # Reset index to HEAD
git read-tree --prefix=sub/ HEAD # Read into subdirectory
Comparison and Diffing
# diff-tree: Compare trees
git diff-tree HEAD HEAD~1 # Compare commits
git diff-tree -r HEAD HEAD~1 # Recursive
# diff-index: Compare index
git diff-index HEAD # Index vs HEAD
git diff-index --cached HEAD # Staged changes
# diff-files: Compare working dir
git diff-files # Working dir vs index
Building Porcelain with Plumbing
Example: Implementing git add with plumbing
#!/bin/bash
# add.sh - Simplified git add implementation
file=$1
# 1. Hash and store file content
hash=$(git hash-object -w "$file")
# 2. Update index
git update-index --add --cacheinfo 100644 "$hash" "$file"
echo "Added $file (hash: $hash)"
Example: Implementing git commit with plumbing
#!/bin/bash
# commit.sh - Simplified git commit implementation
message=$1
# 1. Create tree from current index
tree=$(git write-tree)
# 2. Get parent commit
parent=$(git rev-parse HEAD)
# 3. Create commit object
commit=$(echo "$message" | git commit-tree "$tree" -p "$parent")
# 4. Update branch ref
git update-ref refs/heads/$(git rev-parse --abbrev-ref HEAD) "$commit"
echo "Created commit $commit"
Commit Ancestry and References
Git uses special syntax to refer to commits relative to each other.
Ancestry References
# Ancestor references (follow the first parent)
HEAD~1 # Parent (same as HEAD^)
HEAD~2 # Grandparent
HEAD~3 # Great-grandparent
# Parent selection (merge commits have multiple parents)
HEAD^1 # First parent
HEAD^2 # Second parent (merged branch)
# Combining
HEAD~2^2 # Second parent of the grandparent
Difference between ~ and ^:
- ~ always follows the first parent
- ^ selects which parent of a merge commit
# Merge commit example
    A
   / \
  B   C
  |
  D
# A~1 = B (first parent)
# A^1 = B (first parent)
# A^2 = C (second parent)
# A~2 = D (grandparent via first parent)
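The diagram above can be reproduced in a throwaway repository, with commit subjects standing in for the node names:

```shell
#!/bin/sh
set -e
repo=$(mktemp -d) && cd "$repo"
git init -q -b main
git config user.email "you@example.com" && git config user.name "You"

git commit -qm "D" --allow-empty    # grandparent
git commit -qm "B" --allow-empty    # first parent of the merge
git switch -qc side HEAD~1          # branch from D
git commit -qm "C" --allow-empty    # second parent of the merge
git switch -q main
git merge -q --no-ff -m "A" side    # merge commit A

git show -s --format=%s HEAD^1   # B
git show -s --format=%s HEAD^2   # C
git show -s --format=%s HEAD~2   # D  (first-parent chain: A -> B -> D)
```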
Commit Ranges
# Double dot: Commits in B not in A
git log A..B
# Example: main..feature (commits in feature not in main)
# Triple dot: Commits in A or B, but not both
git log A...B
# Example: main...feature (symmetric difference)
# All ancestors of B excluding A
git log ^A B
git log A..B # Equivalent
# Multiple exclusions
git log ^A ^B C
# Commits in C but not in A or B
Practical examples:
# View commits in feature branch not in main
git log main..feature
# View commits that will be pushed
git log origin/main..HEAD
# View commits in current branch since branching from main
git log main..HEAD
# Show what changed between two branches
git log --oneline main...feature
# Find merge base
git merge-base main feature
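The range operators can be verified end to end in a throwaway repository (branch names are illustrative):

```shell
#!/bin/sh
set -e
repo=$(mktemp -d) && cd "$repo"
git init -q -b main
git config user.email "you@example.com" && git config user.name "You"

git commit -qm "base" --allow-empty
git switch -qc feature
git commit -qm "f1" --allow-empty
git commit -qm "f2" --allow-empty
git switch -q main
git commit -qm "m1" --allow-empty

git rev-list --count main..feature    # 2  (f1, f2: in feature, not in main)
git rev-list --count main...feature   # 3  (f1, f2, m1: symmetric difference)
git show -s --format=%s "$(git merge-base main feature)"   # base
```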
Refspecs
Refspecs define mappings between remote and local refs.
# View refspec
git config --get-regexp remote.origin
# Output:
# remote.origin.url https://github.com/user/repo.git
# remote.origin.fetch +refs/heads/*:refs/remotes/origin/*
# Refspec format:
# [+]<source>:<destination>
# + = force update
Fetch refspec: +refs/heads/*:refs/remotes/origin/*
- Maps all remote branches to local remote-tracking branches
- refs/heads/main → refs/remotes/origin/main
Push refspec: refs/heads/*:refs/heads/*
- Maps local branches to remote branches
- refs/heads/main → refs/heads/main (on remote)
# Custom refspec examples
# Fetch only main branch
git config remote.origin.fetch refs/heads/main:refs/remotes/origin/main
# Fetch all branches (default)
git config remote.origin.fetch '+refs/heads/*:refs/remotes/origin/*'
# Push branch to different name
git push origin local-branch:remote-branch
# Push to different ref
git push origin HEAD:refs/heads/new-branch
# Delete remote branch
git push origin :branch-to-delete
# Or
git push origin --delete branch-to-delete
# Fetch pull request (GitHub)
git fetch origin pull/123/head:pr-123
The Reflog
The reflog records when refs (HEAD, branches) were updated. It's essential for recovery.
Understanding Reflog
# View HEAD reflog
git reflog
# Output:
# a3f2b1c (HEAD -> main) HEAD@{0}: commit: Add feature
# b4e3c2d HEAD@{1}: commit: Fix bug
# 7a2b3c4 HEAD@{2}: checkout: moving from dev to main
# Reflog for specific ref
git reflog show main
git reflog show origin/main
Reflog Syntax
# @{n} - nth prior value
HEAD@{0} # Current HEAD
HEAD@{1} # Previous HEAD
HEAD@{2} # Two steps back
# @{time} - Value at specific time
HEAD@{5.minutes.ago}
HEAD@{yesterday}
HEAD@{2.days.ago}
HEAD@{2023-01-01}
# Examples
git show HEAD@{5} # Show 5th prior HEAD
git diff HEAD@{0} HEAD@{1} # Compare current vs previous
git log -g HEAD # Show reflog as log
Recovery with Reflog
# Scenario: Accidentally reset hard
git reset --hard HEAD~3 # Oops!
# Find commit before reset
git reflog
# a3f2b1c HEAD@{0}: reset: moving to HEAD~3
# b4e3c2d HEAD@{1}: commit: Lost commit
# Recover
git reset --hard HEAD@{1}
# Or
git reset --hard b4e3c2d
# Recover deleted branch
git reflog --all # Show all refs
git branch recovered-branch a3f2b1c
Scenario: Recover from bad rebase
# Before rebase
git log --oneline
# a3f2b1c (HEAD -> feature) Feature work
# b4e3c2d More feature work
# 7a2b3c4 (main) Main work
# Bad interactive rebase (dropped commits)
git rebase -i main
# Accidentally deleted commits!
# Find commits in reflog
git reflog
# 9f8e7d6 HEAD@{0}: rebase -i: finish
# a3f2b1c HEAD@{1}: rebase -i: start
# Reset to before rebase
git reset --hard HEAD@{1}
Reflog Expiration
# Reflogs are temporary (default: 90 days)
# Unreachable commits expire after 30 days
# View reflog expiration config
git config --get gc.reflogExpire # Default: 90 days
git config --get gc.reflogExpireUnreachable # Default: 30 days
# Manually expire reflog
git reflog expire --expire=now --all
git gc --prune=now
# Keep reflog forever (not recommended)
git config gc.reflogExpire never
Branches Under the Hood
Branches are simply refs pointing to commits. Understanding this reveals Git's power.
What is a Branch?
# A branch is just a file containing a commit hash
cat .git/refs/heads/main
# Output: a3f2b1c4d5e6f7a8b9c0d1e2f3a4b5c6d7e8f9a0
# That's it! Just 41 bytes (40 hex characters plus a newline), or even less once packed into .git/packed-refs
Creating Branches with Plumbing
# Porcelain
git branch new-feature
# Plumbing equivalent
git update-ref refs/heads/new-feature HEAD
# Or even more manual
echo $(git rev-parse HEAD) > .git/refs/heads/new-feature
Switching Branches
# Porcelain
git checkout main
# Plumbing steps:
# 1. Update HEAD
git symbolic-ref HEAD refs/heads/main
# 2. Update index and working directory
git read-tree --reset -u HEAD
# 3. That's it!
Merging: Fast-Forward vs Three-Way
Fast-forward merge:
Before:
main feature
↓ ↓
A - B - C - D
After (git merge feature):
main/feature
↓
A - B - C - D
# Fast-forward is just updating the ref
git update-ref refs/heads/main $(git rev-parse feature)
Three-way merge:
Before:
C (main)
/
A - B
\
D (feature)
After (git merge feature):
C - M (main)
/ /
A - B - D
# Three-way merge creates new commit with two parents
# Parents: current HEAD and merged branch
# Plumbing equivalent:
tree=$(git write-tree) # From merge resolution
commit=$(echo "Merge message" | git commit-tree $tree \
-p $(git rev-parse HEAD) \
-p $(git rev-parse feature))
git update-ref refs/heads/main $commit
Remote Tracking
Understanding how Git tracks remote branches.
Remote-Tracking Branches
# Remote-tracking branches are refs under refs/remotes/
ls -la .git/refs/remotes/origin/
# main
# develop
# feature-123
# They're just refs, like local branches
cat .git/refs/remotes/origin/main
# b4e3c2d1... (commit hash)
How git fetch Works
# Porcelain
git fetch origin
# Plumbing steps:
# 1. Connect to remote
# 2. Receive pack of new objects
# 3. Store objects in .git/objects/
# 4. Update refs/remotes/origin/* refs
# Fetch specific branch
git fetch origin main:refs/remotes/origin/main
How git pull Works
# git pull = git fetch + git merge
# Equivalent to:
git fetch origin
git merge origin/main
# Or with rebase:
git fetch origin
git rebase origin/main
How git push Works
# Porcelain
git push origin main
# Plumbing steps:
# 1. Check if fast-forward possible
# 2. Pack objects not on remote
# 3. Send pack to remote
# 4. Remote updates refs/heads/main
# Push uploads the missing objects to the remote, which then updates its ref
# Equivalent refspec:
git push origin refs/heads/main:refs/heads/main
Tracking Branches
# Set upstream branch
git branch --set-upstream-to=origin/main main
# This adds to .git/config:
# [branch "main"]
# remote = origin
# merge = refs/heads/main
# View tracking relationship
git branch -vv
# main a3f2b1c [origin/main] Latest commit
# Remote tracking allows:
git pull # Knows to pull from origin/main
git push # Knows to push to origin/main
Pack Files and Storage Optimization
Git uses pack files to compress objects efficiently.
Loose vs Packed Objects
Loose objects:
- Individual files in .git/objects/ab/cdef...
- Zlib-compressed
- One object per file
- Fast to create, slower to access in bulk
Packed objects:
- Combined into .git/objects/pack/pack-*.pack
- Delta-compressed (stores differences)
- Accompanied by a .idx index file
- Slower to create, much faster to access
Viewing Object Storage
# Count objects
git count-objects -v
# Output:
# count: 150 # Loose objects
# size: 600 # KB
# in-pack: 3500 # Packed objects
# packs: 1 # Number of pack files
# size-pack: 1200 # KB in packs
# prune-packable: 0
# garbage: 0
# size-garbage: 0
# List pack files
ls -lh .git/objects/pack/
# pack-abc123.idx
# pack-abc123.pack
Pack File Structure
# .pack file: Contains compressed objects
# .idx file: Index for finding objects in pack
# Verify pack
git verify-pack -v .git/objects/pack/pack-*.idx
# Output:
# a3f2b1c blob 150 140 12
# b4e3c2d blob 200 185 152
# 7a2b3c4 commit 250 235 337
# ...
# non delta: 150 objects
# chain length = 10: 50 objects
Delta Compression
Git stores deltas (differences) to save space:
# Example: Two similar files
# version1.txt: "Hello World"
# version2.txt: "Hello World!\nNew line"
# Git stores:
# - Full version2.txt (base)
# - Delta: version1 relative to version2
# Verify pack shows delta chains
git verify-pack -v .git/objects/pack/pack-*.idx | grep chain
Garbage Collection
# Manual garbage collection
git gc
# - Packs loose objects
# - Removes unreachable objects
# - Optimizes repository
# Aggressive GC (slow but thorough)
git gc --aggressive
# More thorough delta compression
# Prune unreachable objects
git prune
# Remove objects not reachable from any ref
# Prune everything older than 2 weeks
git prune --expire=2.weeks.ago
# Automatic GC
git config gc.auto 6700 # Auto-gc after 6700 loose objects
git config gc.autopacklimit 50 # Auto-gc after 50 pack files
Optimizing Repository
# Create pack file from scratch
git repack -a -d -f
# -a = all objects
# -d = remove redundant packs
# -f = force
# Aggressive repacking
git repack -a -d -f --depth=250 --window=250
# Reduce repository size
git gc --aggressive --prune=now
# Clone with shallow history (for large repos)
git clone --depth 1 <url>
# Only most recent commit
Advanced Internals Topics
The Index File Format
The index is a binary file with this structure:
Header (12 bytes):
- Signature: "DIRC" (DIrectory Cache)
- Version: 2, 3, or 4
- Number of entries
Entry (variable length):
- ctime/mtime metadata
- Device/inode
- Mode (file permissions)
- UID/GID
- File size
- SHA-1 (20 bytes)
- Flags (name length, stage)
- File name
Extensions:
- Tree cache
- Resolve undo
- etc.
# Dump index in human-readable format
git ls-files --stage --debug
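The 12-byte header described above is plain big-endian binary and can be decoded with a few lines of Python (a sketch that reads only the header, not the entries):

```python
import struct

def parse_index_header(data: bytes):
    """Parse the 12-byte header of a Git index file: 'DIRC', version, entry count."""
    signature, version, entry_count = struct.unpack(">4sLL", data[:12])
    if signature != b"DIRC":
        raise ValueError("not a Git index file")
    return version, entry_count

# Example against a hand-built header: version 2, 3 entries
header = b"DIRC" + struct.pack(">LL", 2, 3)
print(parse_index_header(header))  # (2, 3)
# On a real repository: parse_index_header(open(".git/index", "rb").read())
```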
Object Database Deep Dive
# Find all objects
find .git/objects -type f
# Object file structure:
# - zlib compressed
# - Header: "<type> <size>\0"
# - Content
# Decompress an object manually (objects are zlib streams, not gzip files)
python3 -c 'import sys, zlib; sys.stdout.buffer.write(zlib.decompress(sys.stdin.buffer.read()))' \
< .git/objects/ab/cdef...
# Or use Git's plumbing
git cat-file -p abcdef
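Putting the header format together, a blob's object ID can be reproduced outside Git entirely; the hash below matches what `git hash-object --stdin` prints for the same input:

```python
import hashlib

def git_blob_id(content: bytes) -> str:
    """Compute the SHA-1 object ID Git assigns to a blob: sha1(header + content)."""
    store = b"blob " + str(len(content)).encode() + b"\x00" + content
    return hashlib.sha1(store).hexdigest()

# Same input as: echo 'test content' | git hash-object --stdin
print(git_blob_id(b"test content\n"))
# d670460b4b4aece5915caf5c68d12f560a9fe3e4
```

The on-disk loose object is simply the zlib-compressed form of that same `header + content` byte string.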
Git Hooks and Plumbing
Hooks are scripts in .git/hooks/ that run at specific points.
#!/bin/bash
# Example pre-commit hook using plumbing; save as .git/hooks/pre-commit
# (the shebang must be the first line of the hook file)
# Check for TODO comments in staged files
for file in $(git diff-index --cached --name-only HEAD); do
if git cat-file -p :0:$file | grep -q "TODO"; then
echo "Error: TODO found in $file"
exit 1
fi
done
Inspecting Repository Health
# Check repository integrity
git fsck
# - Verifies object connectivity
# - Checks for corruption
# - Reports dangling/unreachable objects
# Full check
git fsck --full
# Output:
# Checking object directories: 100% (256/256), done.
# Checking objects: 100% (3456/3456), done.
# dangling commit abc123...
# Find large objects
git rev-list --objects --all | \
git cat-file --batch-check='%(objecttype) %(objectname) %(objectsize) %(rest)' | \
awk '/^blob/ {print substr($0,6)}' | \
sort -nk2 | \
tail -20
Practical Plumbing Use Cases
1. Find When File Was Deleted
# Find when file was deleted
git log --all --full-history -- deleted-file.txt
# Using plumbing
git rev-list --all -- deleted-file.txt | while read commit; do
if ! git ls-tree -r $commit | grep -q deleted-file.txt; then
echo "Deleted in: $commit"
break
fi
done
2. Extract File from History
# Get file from specific commit
commit="abc123"
file="path/to/file.txt"
# Find blob hash
blob=$(git ls-tree $commit $file | awk '{print $3}')
# Extract content
git cat-file blob $blob > recovered-file.txt
3. Rewrite History to Remove Sensitive Data
# Remove file from all commits
# (filter-branch is deprecated; prefer git-filter-repo or BFG for real use)
git filter-branch --tree-filter 'rm -f passwords.txt' HEAD
# Or with plumbing (manual approach for understanding):
git rev-list --all | while read commit; do
# Get tree
tree=$(git rev-parse $commit^{tree})
# Create new tree without sensitive file
# (Complex - requires manual tree manipulation)
# Create new commit
new_commit=$(git commit-tree ...)
# Update refs
git update-ref ...
done
4. Create Orphan Branch
# Porcelain
git checkout --orphan new-root
# Plumbing equivalent
# Create empty tree (-w writes the object so commit-tree can reference it)
empty_tree=$(git hash-object -w -t tree /dev/null)
# Create first commit
commit=$(echo "Initial" | git commit-tree $empty_tree)
# Create branch
git update-ref refs/heads/new-root $commit
# Switch to branch
git symbolic-ref HEAD refs/heads/new-root
git reset --hard
5. Analyze Repository Statistics
# Count commits per author
git rev-list --all --pretty=format:'%an' | \
grep -v '^commit' | \
sort | uniq -c | sort -nr
# Find largest commits
git rev-list --objects --all | \
git cat-file --batch-check='%(objecttype) %(objectsize) %(rest)' | \
grep '^commit' | \
sort -k2 -n | \
tail -10
# List all files ever committed
# (rev-list --objects prints bare commit hashes and "<hash> <path>" lines;
#  keep only lines that have a path)
git rev-list --objects --all | \
awk 'NF > 1' | \
cut -d' ' -f2- | \
sort -u
Debugging with Plumbing
Trace Git Commands
# See what Git is doing
GIT_TRACE=1 git commit -m "Test"
# Output shows underlying commands
# Trace pack operations
GIT_TRACE_PACK_ACCESS=1 git fetch
# Trace performance
GIT_TRACE_PERFORMANCE=1 git status
Verbose Object Information
# Find object by content
echo "search content" | git hash-object --stdin
# Check if object exists
git cat-file -e abc123 && echo "exists"
# Batch check objects
echo -e "abc123\ndef456\n789abc" | \
git cat-file --batch-check
# Follow rename history
git log --follow --all -- file.txt
Best Practices
- Don't Modify .git Manually
  - Use plumbing commands instead
  - Prevents corruption
- Understand Before Using
  - Plumbing commands can be destructive
  - Test in disposable repositories first
- Use Reflog for Safety
  - Reflog can recover from mistakes
  - Keep reflog enabled
- Regular Maintenance
  - Run git gc periodically
  - Check health with git fsck
- Backup Before Experiments
  - cp -r .git .git.backup
  - Or use a separate clone
- Learn Incrementally
  - Start with inspection commands
  - Progress to modification commands
  - Master recovery techniques
Summary
Git's internals are elegant and understandable:
- Objects (blob, tree, commit, tag) are the foundation
- Refs are pointers to commits
- Index bridges working directory and repository
- Plumbing commands manipulate these primitives directly
- Pack files optimize storage
- Reflog enables recovery
Understanding internals empowers you to:
- Debug complex issues
- Recover from disasters
- Build custom automation
- Optimize repository performance
- Contribute to Git itself
The next time a porcelain command behaves unexpectedly, you'll understand why and how to fix it using plumbing commands.
Resources
Tools
- git-sizer - Analyze repository size
- BFG Repo-Cleaner - Remove sensitive data
- git-filter-repo - Rewrite history
Remember: With great power (plumbing commands) comes great responsibility. Always have backups!
GitHub
GitHub is a web-based platform that provides hosting for Git repositories along with collaboration features, CI/CD, project management, and more.
Quick Start
SSH Setup
# Generate SSH key
ssh-keygen -t ed25519 -C "your.email@example.com"
# Start SSH agent
eval "$(ssh-agent -s)"
# Add SSH key to agent
ssh-add ~/.ssh/id_ed25519
# Copy public key to clipboard
cat ~/.ssh/id_ed25519.pub
# Then paste in GitHub Settings → SSH and GPG keys → New SSH key
Clone Repository
# Using SSH (recommended)
git clone git@github.com:<username>/<repository>.git
# Using HTTPS
git clone https://github.com/<username>/<repository>.git
# Create new branch
git switch -c <new_branch>
# Push branch to remote
git push -u origin <new_branch>
Pull Request Workflow
Creating a Pull Request
# 1. Create and switch to feature branch
git checkout -b feature/new-feature
# 2. Make changes and commit
git add .
git commit -m "feat: Add new feature"
# 3. Push branch to GitHub
git push -u origin feature/new-feature
# 4. Open browser and create PR
# Navigate to repository → Pull requests → New pull request
# Select your branch → Create pull request
# Or use GitHub CLI
gh pr create --title "Add new feature" --body "Description of changes"
Working on Pull Request
# After creating PR, make additional commits
git add .
git commit -m "Address feedback"
git push origin feature/new-feature
# Update PR with latest main
git checkout main
git pull origin main
git checkout feature/new-feature
git rebase main
git push --force-with-lease origin feature/new-feature
# Submit a review
gh pr review <pr-number> --request-changes --body "Please fix..."
gh pr review <pr-number> --approve --body "LGTM!"
Reviewing Pull Requests
# Check out PR locally
gh pr checkout <pr-number>
# Or manually:
git fetch origin pull/<pr-number>/head:pr-<pr-number>
git checkout pr-<pr-number>
# Test changes
npm test
npm run build
# Add review comments
gh pr review <pr-number> --comment --body "Looks good!"
# Approve PR
gh pr review <pr-number> --approve
# Request changes
gh pr review <pr-number> --request-changes --body "Please address..."
Merging Pull Requests
# Merge via GitHub CLI
gh pr merge <pr-number> --merge
gh pr merge <pr-number> --squash
gh pr merge <pr-number> --rebase
# Via web interface:
# - Merge commit: Preserves all commits
# - Squash and merge: Combines all commits into one
# - Rebase and merge: Adds commits to base branch
# After merging, cleanup
git checkout main
git pull origin main
git branch -d feature/new-feature
GitHub Issues
Creating Issues
# Create issue via CLI
gh issue create --title "Bug: Login fails" --body "Description of bug"
# Create with labels
gh issue create --title "Feature request" --label "enhancement"
# List issues
gh issue list
gh issue list --label "bug"
gh issue list --assignee "@me"
# View issue
gh issue view <issue-number>
Working with Issues
# Assign issue
gh issue edit <issue-number> --add-assignee "@me"
# Add labels
gh issue edit <issue-number> --add-label "bug,high-priority"
# Close issue
gh issue close <issue-number>
# Reopen issue
gh issue reopen <issue-number>
# Link PR to issue (in commit or PR description)
git commit -m "Fix login bug
Fixes #123"
# Or "Closes #123", "Resolves #123"
Issue Templates
Create .github/ISSUE_TEMPLATE/bug_report.md:
---
name: Bug Report
about: Create a report to help us improve
title: '[BUG] '
labels: bug
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. See error
**Expected behavior**
A clear description of what you expected to happen.
**Screenshots**
If applicable, add screenshots.
**Environment:**
- OS: [e.g. Ubuntu 22.04]
- Browser: [e.g. Chrome 120]
- Version: [e.g. v1.2.3]
GitHub Actions (CI/CD)
Basic Workflow
Create .github/workflows/ci.yml:
name: CI
on:
push:
branches: [ main, develop ]
pull_request:
branches: [ main ]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Setup Node.js
uses: actions/setup-node@v3
with:
node-version: '18'
- name: Install dependencies
run: npm ci
- name: Run tests
run: npm test
- name: Run linter
run: npm run lint
- name: Build
run: npm run build
Action Permissions
# Enable workflow permissions
# Repository Settings → Actions → General → Workflow permissions
# Select: Read and write permissions
# Or in workflow file
permissions:
contents: write
pull-requests: write
issues: write
Deployment Workflow
Create .github/workflows/deploy.yml:
name: Deploy
on:
push:
branches: [ main ]
release:
types: [ published ]
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Build
run: npm run build
- name: Deploy to GitHub Pages
uses: peaceiris/actions-gh-pages@v3
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
publish_dir: ./build
cname: yourdomain.com
Useful GitHub Actions
# Test on multiple Node versions
strategy:
matrix:
node-version: [16, 18, 20]
# Cache dependencies
- name: Cache node modules
uses: actions/cache@v3
with:
path: ~/.npm
key: ${{ runner.os }}-node-${{ hashFiles('**/package-lock.json') }}
# Create release
- name: Create Release
uses: actions/create-release@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
tag_name: ${{ github.ref }}
release_name: Release ${{ github.ref }}
# Comment on PR
- name: Comment on PR
uses: actions/github-script@v6
with:
script: |
github.rest.issues.createComment({
issue_number: context.issue.number,
owner: context.repo.owner,
repo: context.repo.repo,
body: 'Build successful! ✅'
})
GitHub Pages
Setup GitHub Pages
# 1. Create gh-pages branch
git checkout --orphan gh-pages
git rm -rf .
echo "<!DOCTYPE html><html><body><h1>Hello World</h1></body></html>" > index.html
git add index.html
git commit -m "Initial GitHub Pages commit"
git push origin gh-pages
# 2. Enable in repository settings
# Settings → Pages → Source → Select branch: gh-pages, folder: /(root)
# 3. Add custom domain (optional)
echo "yourdomain.com" > CNAME
git add CNAME
git commit -m "Add custom domain"
git push origin gh-pages
Deploy with Actions
name: Deploy to GitHub Pages
on:
push:
branches: [ main ]
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Setup Node.js
uses: actions/setup-node@v3
with:
node-version: '18'
- name: Build
run: |
npm ci
npm run build
- name: Deploy
uses: peaceiris/actions-gh-pages@v3
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
publish_dir: ./build
cname: yourdomain.com # Optional
GitHub CLI (gh)
Installation
# macOS
brew install gh
# Linux (Debian/Ubuntu)
curl -fsSL https://cli.github.com/packages/githubcli-archive-keyring.gpg | sudo dd of=/usr/share/keyrings/githubcli-archive-keyring.gpg
echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" | sudo tee /etc/apt/sources.list.d/github-cli.list > /dev/null
sudo apt update
sudo apt install gh
# Authenticate
gh auth login
Common gh Commands
# Repository
gh repo create <name>
gh repo clone <repo>
gh repo view
gh repo fork
# Pull Requests
gh pr create
gh pr list
gh pr view <number>
gh pr checkout <number>
gh pr merge <number>
gh pr diff <number>
gh pr review <number>
# Issues
gh issue create
gh issue list
gh issue view <number>
gh issue close <number>
# Releases
gh release create v1.0.0
gh release list
gh release download v1.0.0
# Workflows
gh workflow list
gh workflow run <workflow>
gh run list
gh run view <run-id>
# Gists
gh gist create <file>
gh gist list
Collaboration Workflows
Fork and Contribute
# 1. Fork repository on GitHub (click Fork button)
# 2. Clone your fork
gh repo fork <original-repo> --clone
# 3. Add upstream remote
git remote add upstream https://github.com/original-owner/repo.git
# 4. Create feature branch
git checkout -b feature/my-contribution
# 5. Make changes and commit
git add .
git commit -m "Add feature"
# 6. Keep fork updated
git fetch upstream
git checkout main
git merge upstream/main
git push origin main
# 7. Push feature branch
git push origin feature/my-contribution
# 8. Create pull request
gh pr create --base main --head feature/my-contribution
Code Review Best Practices
# As PR author:
# - Keep PRs small and focused
# - Write clear description
# - Link related issues
# - Respond to feedback promptly
# As reviewer:
# - Review promptly
# - Be constructive and specific
# - Test changes locally
# - Approve or request changes
# Request specific reviewers
gh pr create --reviewer @username1,@username2
# Check PR status
gh pr status
# View PR checks
gh pr checks <pr-number>
Team Collaboration
# Protect branches
# Repository Settings → Branches → Add branch protection rule
# - Require pull request reviews
# - Require status checks to pass
# - Require branches to be up to date
# - Include administrators
# Add collaborators
# Repository Settings → Collaborators → Add people
# Use code owners (.github/CODEOWNERS)
# Require approval from code owners
* @team-name
/docs/ @docs-team
*.js @frontend-team
Project Management
GitHub Projects
# Create project
gh project create --title "My Project"
# Add issues to project
gh issue create --project "My Project"
# View project
gh project view <project-number>
Milestones
# Create milestone
# Issues → Milestones → New milestone
# Assign issue to milestone
gh issue edit <number> --milestone "v1.0"
# View milestone progress
# Issues → Milestones
Security Features
Dependabot
Enable in Settings → Security → Dependabot:
- Dependabot alerts
- Dependabot security updates
- Dependabot version updates
Create .github/dependabot.yml:
version: 2
updates:
- package-ecosystem: "npm"
directory: "/"
schedule:
interval: "weekly"
open-pull-requests-limit: 10
Secret Scanning
# Enable in Settings → Security → Code security and analysis
# - Secret scanning
# - Secret scanning push protection
# Use secrets in workflows
steps:
- name: Deploy
env:
API_KEY: ${{ secrets.API_KEY }}
run: deploy.sh
Security Advisories
# Create security advisory
# Security → Advisories → New draft security advisory
# Report vulnerability privately
# Contact repository maintainers through security tab
Webhooks and API
Setup Webhook
# Repository Settings → Webhooks → Add webhook
# Payload URL: https://your-server.com/webhook
# Content type: application/json
# Events: Push, Pull request, Issues, etc.
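When a secret is configured, GitHub signs each delivery and sends the signature in the X-Hub-Signature-256 header; a receiver should verify it against the raw request body before trusting the payload. A minimal verification sketch (the secret and payload here are made-up examples):

```python
import hashlib
import hmac

def verify_signature(secret: bytes, payload: bytes, signature_header: str) -> bool:
    """Check GitHub's X-Hub-Signature-256 header against the raw request body."""
    expected = "sha256=" + hmac.new(secret, payload, hashlib.sha256).hexdigest()
    # Constant-time comparison to avoid timing attacks
    return hmac.compare_digest(expected, signature_header)

secret = b"my-webhook-secret"       # value configured in the webhook settings
payload = b'{"action": "opened"}'   # raw bytes of the delivery body
header = "sha256=" + hmac.new(secret, payload, hashlib.sha256).hexdigest()
assert verify_signature(secret, payload, header)
assert not verify_signature(secret, b'{"tampered": true}', header)
```

Always hash the raw body bytes as received; re-serializing parsed JSON will usually produce a different byte string and a failed check.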
GitHub API
# Using curl
curl -H "Authorization: token YOUR_TOKEN" \
https://api.github.com/user/repos
# Using gh CLI with API
gh api repos/<owner>/<repo>/pulls
gh api graphql -f query='
query {
repository(owner: "owner", name: "repo") {
pullRequests(first: 10) {
nodes {
title
number
}
}
}
}
'
Advanced Features
GitHub Codespaces
# Create codespace
gh codespace create --repo <repo>
# List codespaces
gh codespace list
# Connect to codespace
gh codespace ssh
GitHub Packages
Publish package:
- name: Publish to GitHub Packages
run: npm publish
env:
NODE_AUTH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GitHub Discussions
# Enable in Settings → General → Features → Discussions
# Creating discussions requires the GraphQL API (there is no REST endpoint)
gh api graphql -f query='
mutation {
  createDiscussion(input: {repositoryId: "<repo-id>", categoryId: "<category-id>",
                           title: "Discussion title", body: "Discussion body"}) {
    discussion { number }
  }
}'
Best Practices
- Branch Protection: Enable branch protection on main/develop
- Required Reviews: Require at least one approval before merging
- Status Checks: Require CI/CD to pass before merging
- Linear History: Use squash or rebase merging for clean history
- Signed Commits: Enable commit signing for security
- Templates: Use PR and issue templates for consistency
- Labels: Use labels to categorize issues and PRs
- Milestones: Track progress with milestones
- Projects: Use GitHub Projects for project management
- Documentation: Keep README and CONTRIBUTING.md updated
Programming Languages
This section contains references and guides for various programming languages.
Available Languages
- Python - A high-level, interpreted programming language
- C - A general-purpose, procedural programming language
- C++ - An extension of C with object-oriented features
- JavaScript - A scripting language for web development
- TypeScript - A strongly-typed superset of JavaScript for large-scale applications
- Bash - A Unix shell and command language
- Java - A class-based, object-oriented programming language
- Go - A statically typed, compiled language with built-in concurrency
- Lua - A lightweight, embeddable scripting language
- Rust - A systems programming language focused on safety and performance
- SQL - A domain-specific language for managing databases
Additional Topics
- Design Patterns - Common software design patterns and best practices
- Interview Questions - Common programming interview questions and solutions
Python Programming
Overview
Python is a high-level, interpreted, dynamically-typed programming language known for its simplicity and readability. It's widely used for web development, data science, machine learning, automation, and scripting.
Key Features:
- Clean, readable syntax emphasizing indentation
- Dynamic typing with strong type checking
- Extensive standard library ("batteries included")
- Large ecosystem of third-party packages (PyPI)
- Multi-paradigm: procedural, object-oriented, functional
Basic Syntax
Variables and Data Types
# Variables (no declaration needed)
x = 10 # int
y = 3.14 # float
name = "Alice" # str
is_valid = True # bool
# Type checking and conversion
print(type(x)) # <class 'int'>
num_str = str(42) # Convert to string
num_int = int("42") # Convert to int
Print and String Formatting
# Basic print
print("Hello, World!")
# f-strings (Python 3.6+)
name = "Bob"
age = 30
print(f"{name} is {age} years old")
# .format() method
print("{} is {} years old".format(name, age))
# %-formatting (older style)
print("%s is %d years old" % (name, age))
Data Structures
Lists
Lists are mutable, ordered sequences that can contain mixed types.
# Creating lists
my_list = [1, 2, 3, 4, 5]
mixed = [1, "hello", 3.14, True]
empty = []
# Common operations
my_list.append(6) # Add to end: [1, 2, 3, 4, 5, 6]
my_list.insert(0, 0) # Insert at index: [0, 1, 2, 3, 4, 5, 6]
my_list.pop() # Remove and return last: 6
my_list.remove(3) # Remove first occurrence of value
element = my_list[2] # Access by index
my_list[1] = 10 # Modify by index
# Slicing
first_three = my_list[0:3] # [0, 1, 2]
last_two = my_list[-2:] # Last 2 elements
reversed_list = my_list[::-1] # Reverse
# List comprehensions
squares = [x**2 for x in range(10)]
evens = [x for x in range(20) if x % 2 == 0]
# Common methods
len(my_list) # Length
my_list.sort() # Sort in-place
sorted(my_list) # Return sorted copy
my_list.reverse() # Reverse in-place
my_list.count(2) # Count occurrences
my_list.index(2) # Find first index
List Characteristics:
- Mutable (can change)
- Ordered (maintains insertion order)
- Allows duplicates
- Can be nested
- Dynamic sizing
Tuples
Tuples are immutable, ordered sequences.
# Creating tuples
my_tuple = (1, 2, 3, 4, 5)
single = (42,) # Single element needs comma
empty = ()
# Accessing elements
first = my_tuple[0]
last = my_tuple[-1]
sub = my_tuple[1:3]
# Unpacking
x, y, z = (1, 2, 3)
a, *rest, b = (1, 2, 3, 4, 5) # a=1, rest=[2,3,4], b=5
# Named tuples
from collections import namedtuple
Point = namedtuple('Point', ['x', 'y'])
p = Point(10, 20)
print(p.x, p.y) # 10 20
Tuple Characteristics:
- Immutable (cannot change)
- Ordered
- Faster than lists
- Can be used as dictionary keys
- Used for function return values
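Because tuples are immutable (and hashable when their elements are), they work as dictionary keys where lists would raise a TypeError, e.g. for a sparse grid keyed by coordinates:

```python
# Tuples as dictionary keys: a sparse grid keyed by (x, y) coordinates
grid = {}
grid[(0, 0)] = "origin"
grid[(2, 3)] = "treasure"
print(grid.get((2, 3)))   # treasure
print(grid.get((9, 9)))   # None (empty cell)

# A list key fails because lists are mutable and unhashable
try:
    grid[[0, 0]] = "nope"
except TypeError as e:
    print("TypeError:", e)
```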
Dictionaries
Dictionaries are mutable key-value collections; since Python 3.7 they preserve insertion order.
# Creating dictionaries
person = {
"name": "Alice",
"age": 30,
"city": "NYC"
}
empty = {}
from_keys = dict.fromkeys(['a', 'b', 'c'], 0) # {'a': 0, 'b': 0, 'c': 0}
# Accessing and modifying
name = person["name"] # KeyError if not exists
age = person.get("age", 0) # Returns default if not exists
person["email"] = "alice@example.com" # Add/update
del person["city"] # Delete key
# Methods
person.keys() # dict_keys(['name', 'age', 'email'])
person.values() # dict_values(['Alice', 30, 'alice@example.com'])
person.items() # dict_items([('name', 'Alice'), ...])
# Iteration
for key in person:
print(key, person[key])
for key, value in person.items():
print(f"{key}: {value}")
# Dictionary comprehension
squares = {x: x**2 for x in range(5)}
# {0: 0, 1: 1, 2: 4, 3: 9, 4: 16}
# Merge dictionaries (Python 3.9+)
dict1 = {"a": 1, "b": 2}
dict2 = {"c": 3, "d": 4}
merged = dict1 | dict2
Sets
Sets are mutable, unordered collections of unique elements.
# Creating sets
my_set = {1, 2, 3, 4, 5}
empty = set() # Note: {} creates empty dict
from_list = set([1, 2, 2, 3, 3, 3]) # {1, 2, 3}
# Operations
my_set.add(6)
my_set.remove(3) # KeyError if not exists
my_set.discard(3) # No error if not exists
my_set.pop() # Remove and return arbitrary element
# Set operations
a = {1, 2, 3, 4}
b = {3, 4, 5, 6}
union = a | b # {1, 2, 3, 4, 5, 6}
intersection = a & b # {3, 4}
difference = a - b # {1, 2}
symmetric_diff = a ^ b # {1, 2, 5, 6}
# Set comprehension
evens = {x for x in range(10) if x % 2 == 0}
Control Flow
If-Elif-Else
age = 18
if age < 13:
print("Child")
elif age < 20:
print("Teenager")
else:
print("Adult")
# Ternary operator
status = "Adult" if age >= 18 else "Minor"
# Check multiple conditions
if 10 < age < 20:
print("Teenager")
Loops
# For loop
for i in range(5):
print(i) # 0, 1, 2, 3, 4
for i in range(0, 10, 2): # Start, stop, step
print(i) # 0, 2, 4, 6, 8
# Iterate over list
fruits = ["apple", "banana", "cherry"]
for fruit in fruits:
print(fruit)
# Enumerate (get index and value)
for idx, fruit in enumerate(fruits):
print(f"{idx}: {fruit}")
# While loop
count = 0
while count < 5:
print(count)
count += 1
# Break and continue
for i in range(10):
if i == 3:
continue # Skip 3
if i == 7:
break # Stop at 7
print(i)
# Else clause (runs if loop completes without break)
for i in range(5):
print(i)
else:
print("Loop completed")
Functions
Basic Functions
# Simple function
def greet(name):
return f"Hello, {name}!"
# Default arguments
def greet(name="World"):
return f"Hello, {name}!"
# Multiple return values
def get_stats(numbers):
return min(numbers), max(numbers), sum(numbers)/len(numbers)
minimum, maximum, average = get_stats([1, 2, 3, 4, 5])
# *args (variable positional arguments)
def sum_all(*args):
return sum(args)
print(sum_all(1, 2, 3, 4, 5)) # 15
# **kwargs (variable keyword arguments)
def print_info(**kwargs):
for key, value in kwargs.items():
print(f"{key}: {value}")
print_info(name="Alice", age=30, city="NYC")
# Lambda functions
square = lambda x: x**2
add = lambda x, y: x + y
# Map, Filter, Reduce
numbers = [1, 2, 3, 4, 5]
squared = list(map(lambda x: x**2, numbers))
evens = list(filter(lambda x: x % 2 == 0, numbers))
from functools import reduce
product = reduce(lambda x, y: x * y, numbers) # 120
Decorators
# Simple decorator
def my_decorator(func):
def wrapper(*args, **kwargs):
print("Before function call")
result = func(*args, **kwargs)
print("After function call")
return result
return wrapper
@my_decorator
def say_hello():
print("Hello!")
say_hello()
# Output:
# Before function call
# Hello!
# After function call
# Decorator with arguments
def repeat(times):
def decorator(func):
def wrapper(*args, **kwargs):
for _ in range(times):
result = func(*args, **kwargs)
return result
return wrapper
return decorator
@repeat(3)
def greet(name):
print(f"Hello, {name}!")
greet("Alice") # Prints 3 times
# Common built-in decorators
class MyClass:
@staticmethod
def static_method():
print("Static method")
@classmethod
def class_method(cls):
print(f"Class method of {cls}")
@property
def value(self):
return self._value
@value.setter
def value(self, val):
self._value = val
Object-Oriented Programming
Classes and Objects
class Person:
# Class variable (shared by all instances)
species = "Homo sapiens"
def __init__(self, name, age):
# Instance variables
self.name = name
self.age = age
# Instance method
def greet(self):
return f"Hello, I'm {self.name} and I'm {self.age} years old"
# Magic methods
def __str__(self):
return f"Person({self.name}, {self.age})"
def __repr__(self):
return f"Person('{self.name}', {self.age})"
def __eq__(self, other):
return self.name == other.name and self.age == other.age
# Creating objects
person1 = Person("Alice", 30)
person2 = Person("Bob", 25)
print(person1.greet())
print(str(person1))
Inheritance
# Single inheritance
class Animal:
def __init__(self, name):
self.name = name
def speak(self):
pass
class Dog(Animal):
def speak(self):
return f"{self.name} says Woof!"
class Cat(Animal):
def speak(self):
return f"{self.name} says Meow!"
dog = Dog("Buddy")
cat = Cat("Whiskers")
print(dog.speak()) # Buddy says Woof!
# Multiple inheritance
class Flyer:
def fly(self):
return "Flying..."
class Swimmer:
def swim(self):
return "Swimming..."
class Duck(Animal, Flyer, Swimmer):
def speak(self):
return f"{self.name} says Quack!"
duck = Duck("Donald")
print(duck.speak()) # Donald says Quack!
print(duck.fly()) # Flying...
print(duck.swim()) # Swimming...
# super() for parent class
class Employee(Person):
def __init__(self, name, age, employee_id):
super().__init__(name, age)
self.employee_id = employee_id
def get_info(self):
return f"{self.greet()}, ID: {self.employee_id}"
Data Classes (Python 3.7+)
from dataclasses import dataclass, field
from typing import List
@dataclass
class Person:
name: str
age: int
email: str = "unknown@example.com" # Default value
hobbies: List[str] = field(default_factory=list)
def greet(self):
return f"Hello, I'm {self.name}"
person = Person("Alice", 30)
print(person) # Person(name='Alice', age=30, email='unknown@example.com', hobbies=[])
# Frozen (immutable) dataclass
@dataclass(frozen=True)
class Point:
x: int
y: int
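Assigning to a field of a frozen instance raises `dataclasses.FrozenInstanceError`, and frozen dataclasses are hashable, so they can be used as dictionary keys. A quick check:

```python
from dataclasses import dataclass, FrozenInstanceError

@dataclass(frozen=True)
class Point:
    x: int
    y: int

p = Point(1, 2)
try:
    p.x = 10  # raises FrozenInstanceError
except FrozenInstanceError:
    print("Point is immutable")

# Frozen dataclasses get a __hash__, so they work as dict keys
origin_names = {Point(0, 0): "origin"}
print(origin_names[Point(0, 0)])  # origin
```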
Common Patterns
Singleton Pattern
class Singleton:
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
# All instances are the same
s1 = Singleton()
s2 = Singleton()
print(s1 is s2) # True
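One subtlety with this `__new__`-based singleton: `__init__` still runs on every `Singleton()` call, even when `__new__` returns the cached instance, so any one-time setup needs a guard. A sketch illustrating the behavior:

```python
class Singleton:
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Runs on EVERY Singleton() call; guard one-time setup
        if not hasattr(self, "initialized"):
            self.initialized = True
            self.calls = 0
        self.calls += 1

s1 = Singleton()
s2 = Singleton()
print(s1 is s2)   # True
print(s1.calls)   # 2 -- __init__ ran twice on the same object
```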
Factory Pattern
class Dog:
def speak(self):
return "Woof!"
class Cat:
def speak(self):
return "Meow!"
class AnimalFactory:
@staticmethod
def create_animal(animal_type):
if animal_type == "dog":
return Dog()
elif animal_type == "cat":
return Cat()
else:
raise ValueError(f"Unknown animal type: {animal_type}")
# Usage
animal = AnimalFactory.create_animal("dog")
print(animal.speak()) # Woof!
Context Manager Pattern
# Custom context manager
class FileManager:
def __init__(self, filename, mode):
self.filename = filename
self.mode = mode
self.file = None
def __enter__(self):
self.file = open(self.filename, self.mode)
return self.file
def __exit__(self, exc_type, exc_val, exc_tb):
if self.file:
self.file.close()
# Usage
with FileManager('test.txt', 'w') as f:
f.write('Hello, World!')
# Using contextlib
from contextlib import contextmanager
@contextmanager
def file_manager(filename, mode):
f = open(filename, mode)
try:
yield f
finally:
f.close()
with file_manager('test.txt', 'r') as f:
content = f.read()
Iterator and Generator Patterns
# Iterator
class Counter:
def __init__(self, start, end):
self.current = start
self.end = end
def __iter__(self):
return self
def __next__(self):
if self.current > self.end:
raise StopIteration
current = self.current
self.current += 1
return current
for num in Counter(1, 5):
print(num) # 1, 2, 3, 4, 5
# Generator
def counter(start, end):
while start <= end:
yield start
start += 1
for num in counter(1, 5):
print(num)
# Generator expressions
squares = (x**2 for x in range(10))
print(next(squares)) # 0
print(next(squares)) # 1
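Because generators are lazy, they can even be infinite; consumers take only as many values as they need, for example with `itertools.islice`:

```python
from itertools import islice

def fibonacci():
    """Infinite Fibonacci generator -- values are produced lazily."""
    a, b = 0, 1
    while True:
        yield a
        a, b = b, a + b

first_eight = list(islice(fibonacci(), 8))
print(first_eight)  # [0, 1, 1, 2, 3, 5, 8, 13]
```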
File Handling
# Reading files
with open('file.txt', 'r', encoding='utf-8') as f:
content = f.read() # Read entire file
with open('file.txt', 'r', encoding='utf-8') as f:
lines = f.readlines() # Read all lines as list
with open('file.txt', 'r', encoding='utf-8') as f:
for line in f: # Iterate line by line
print(line.strip())
# Writing files
with open('file.txt', 'w', encoding='utf-8') as f:
f.write('Hello, World!\n')
# Appending
with open('file.txt', 'a', encoding='utf-8') as f:
f.write('New line\n')
# Binary files
with open('image.png', 'rb') as f:
data = f.read()
# JSON files
import json
# Write JSON
data = {"name": "Alice", "age": 30}
with open('data.json', 'w', encoding='utf-8') as f:
json.dump(data, f, indent=2)
# Read JSON
with open('data.json', 'r', encoding='utf-8') as f:
data = json.load(f)
# CSV files
import csv
# Write CSV
with open('data.csv', 'w', newline='', encoding='utf-8') as f:
writer = csv.writer(f)
writer.writerow(['Name', 'Age', 'City'])
writer.writerow(['Alice', 30, 'NYC'])
# Read CSV
with open('data.csv', 'r', encoding='utf-8') as f:
reader = csv.reader(f)
for row in reader:
print(row)
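`csv.DictReader` and `csv.DictWriter` map each row to a dictionary keyed by the header, which is often more readable than positional indexing. A self-contained sketch using an in-memory buffer instead of a file on disk:

```python
import csv
import io

# io.StringIO stands in for a real file so the example is self-contained
buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=["Name", "Age", "City"])
writer.writeheader()
writer.writerow({"Name": "Alice", "Age": 30, "City": "NYC"})

buffer.seek(0)
for row in csv.DictReader(buffer):
    print(row["Name"], row["Age"])  # Alice 30  (values are read back as strings)
```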
Error Handling
# Try-except
try:
result = 10 / 0
except ZeroDivisionError:
print("Cannot divide by zero!")
# Multiple exceptions
try:
value = int("abc")
except (ValueError, TypeError) as e:
print(f"Error: {e}")
# Catch all exceptions
try:
risky_operation()
except Exception as e:
print(f"An error occurred: {e}")
# Finally block
f = None
try:
    f = open('file.txt', 'r')
    content = f.read()
except FileNotFoundError:
    print("File not found")
finally:
    if f is not None:   # guard: f is never bound if open() raised
        f.close()       # finally always executes
# Else block
try:
result = 10 / 2
except ZeroDivisionError:
print("Error")
else:
print("Success!") # Runs if no exception
# Raising exceptions
def validate_age(age):
if age < 0:
raise ValueError("Age cannot be negative")
return age
# Custom exceptions
class InvalidAgeError(Exception):
pass
def check_age(age):
if age < 0:
raise InvalidAgeError("Age must be positive")
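A custom exception is caught exactly like a built-in one; a small usage sketch for the class above:

```python
class InvalidAgeError(Exception):
    pass

def check_age(age):
    if age < 0:
        raise InvalidAgeError("Age must be positive")
    return age

try:
    check_age(-5)
except InvalidAgeError as e:
    print(f"Validation failed: {e}")  # Validation failed: Age must be positive
```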
Working with Excel and Pandas
from dataclasses import dataclass
import pandas as pd
from typing import List
@dataclass
class Person:
name: str
age: int
email: str
# Load data from Excel
def load_people_from_excel(file_path: str) -> List[Person]:
df = pd.read_excel(file_path)
return [
Person(
name=row['name'],
age=row['age'],
email=row['email']
) for _, row in df.iterrows()
]
# Usage
people = load_people_from_excel("data.xlsx")
for person in people:
print(f"{person.name} is {person.age} years old")
# With column mapping
EXCEL_TO_CLASS_MAPPING = {
'Full Name': 'name',
'Person Age': 'age',
'E-mail Address': 'email'
}
def load_with_mapping(file_path: str) -> List[Person]:
df = pd.read_excel(file_path)
df = df.rename(columns=EXCEL_TO_CLASS_MAPPING)
return [Person(**row) for _, row in df.iterrows()]
Common Python Idioms
# List comprehension vs map/filter
numbers = range(10)
squares = [x**2 for x in numbers]
evens = [x for x in numbers if x % 2 == 0]
# Dictionary get with default
config = {"debug": True}
log_level = config.get("log_level", "INFO")
# String joining
words = ["Hello", "World"]
sentence = " ".join(words)
# Enumerate
for idx, value in enumerate(['a', 'b', 'c']):
print(f"{idx}: {value}")
# Zip
names = ["Alice", "Bob", "Charlie"]
ages = [25, 30, 35]
for name, age in zip(names, ages):
print(f"{name} is {age} years old")
# Unpacking
first, *middle, last = [1, 2, 3, 4, 5]
# Swapping variables
a, b = 10, 20
a, b = b, a
# Chaining comparisons
x = 5
if 0 < x < 10:
    print("x is between 0 and 10")
# In-place operations
numbers = [1, 2, 3]
numbers += [4, 5] # Extend list
# any() and all()
numbers = [2, 4, 6, 8]
all_even = all(x % 2 == 0 for x in numbers)
has_even = any(x % 2 == 0 for x in numbers)
Virtual Environments
# Create virtual environment
python -m venv venv
# Activate (Linux/Mac)
source venv/bin/activate
# Activate (Windows)
venv\Scripts\activate
# Install packages
pip install requests pandas numpy
# Save dependencies
pip freeze > requirements.txt
# Install from requirements
pip install -r requirements.txt
# Deactivate
deactivate
Best Practices
- Follow PEP 8: Python's style guide
  - Use 4 spaces for indentation
  - Max line length: 79 characters
  - Use snake_case for functions and variables
  - Use PascalCase for classes
- Use Type Hints (Python 3.5+)

  def greet(name: str) -> str:
      return f"Hello, {name}!"

  from typing import List, Dict
  def process_data(items: List[int]) -> Dict[str, int]:
      return {"sum": sum(items), "count": len(items)}

- Use List Comprehensions for simple transformations

  # Good
  squares = [x**2 for x in range(10)]
  # Avoid for complex logic; use regular loops instead

- Use Context Managers for resource management

  with open('file.txt', 'r') as f:
      data = f.read()

- Use f-strings for string formatting (Python 3.6+)

  name = "Alice"
  age = 30
  print(f"{name} is {age} years old")
Common Libraries
- Requests: HTTP requests
- NumPy: Numerical computing
- Pandas: Data analysis
- Matplotlib/Seaborn: Data visualization
- Flask/Django: Web frameworks
- SQLAlchemy: Database ORM
- pytest: Testing
- Beautiful Soup: Web scraping
- Pillow: Image processing
C Programming
Overview
C is a general-purpose, procedural programming language developed by Dennis Ritchie at Bell Labs in 1972. It's widely used for system programming, embedded systems, operating systems (Unix/Linux), and applications requiring high performance and low-level memory access.
Key Features:
- Low-level access to memory via pointers
- Efficient execution with minimal runtime overhead
- Portable across different platforms
- Rich library of functions
- Structured programming with functions and modular code
- Static typing with compile-time type checking
Basic Syntax
Program Structure
#include <stdio.h> // Preprocessor directive
#include <stdlib.h>
// Function prototype
int add(int a, int b);
// Main function - entry point
int main(void) {
printf("Hello, World!\n");
int result = add(5, 3);
printf("5 + 3 = %d\n", result);
return 0; // Return success code
}
// Function definition
int add(int a, int b) {
return a + b;
}
Comments
// Single-line comment
/*
* Multi-line comment
* Spans multiple lines
*/
Data Types
Primitive Data Types
// Integer types
char c = 'A'; // 1 byte: -128 to 127
unsigned char uc = 255; // 1 byte: 0 to 255
short s = 32000; // 2 bytes: -32,768 to 32,767
unsigned short us = 65000; // 2 bytes: 0 to 65,535
int i = 100000; // 4 bytes: -2,147,483,648 to 2,147,483,647
unsigned int ui = 400000; // 4 bytes: 0 to 4,294,967,295
long l = 1000000L; // 4 or 8 bytes (platform-dependent)
unsigned long ul = 2000000UL;
long long ll = 9223372036854775807LL; // 8 bytes
// Floating-point types
float f = 3.14f; // 4 bytes, ~7 decimal digits precision
double d = 3.14159265359; // 8 bytes, ~15 decimal digits precision
long double ld = 3.14159265358979323846L; // 10-16 bytes
// Boolean (C99 and later)
#include <stdbool.h>
bool flag = true; // true or false
Size of Data Types
#include <stdio.h>
int main(void) {
printf("Size of char: %zu bytes\n", sizeof(char));
printf("Size of int: %zu bytes\n", sizeof(int));
printf("Size of float: %zu bytes\n", sizeof(float));
printf("Size of double: %zu bytes\n", sizeof(double));
printf("Size of pointer: %zu bytes\n", sizeof(void*));
return 0;
}
Variables and Constants
Variable Declaration
int x; // Declaration
int y = 10; // Declaration with initialization
int a, b, c; // Multiple declarations
int m = 5, n = 10; // Multiple with initialization
// Variable naming rules:
// - Must start with letter or underscore
// - Can contain letters, digits, underscores
// - Case-sensitive
// - Cannot use reserved keywords
Constants
// Using const keyword
const int MAX_SIZE = 100;
const double PI = 3.14159;
// Using #define preprocessor
#define BUFFER_SIZE 1024
#define TRUE 1
#define FALSE 0
// Enumeration constants
enum Color {
RED, // 0
GREEN, // 1
BLUE // 2
};
enum Status {
SUCCESS = 0,
ERROR = -1,
PENDING = 1
};
Operators
Arithmetic Operators
int a = 10, b = 3;
int sum = a + b; // Addition: 13
int diff = a - b; // Subtraction: 7
int prod = a * b; // Multiplication: 30
int quot = a / b; // Division: 3 (integer division)
int rem = a % b; // Modulus: 1
// Increment/Decrement
int x = 5;
x++; // Post-increment: x = 6
++x; // Pre-increment: x = 7
x--; // Post-decrement: x = 6
--x; // Pre-decrement: x = 5
Relational Operators
int a = 5, b = 10;
int result;
result = (a == b); // Equal to: 0 (false)
result = (a != b); // Not equal: 1 (true)
result = (a > b); // Greater than: 0
result = (a < b); // Less than: 1
result = (a >= b); // Greater than or equal: 0
result = (a <= b); // Less than or equal: 1
Logical Operators
int a = 1, b = 0;
int and_result = a && b; // Logical AND: 0
int or_result = a || b; // Logical OR: 1
int not_result = !a; // Logical NOT: 0
Bitwise Operators
unsigned int a = 5; // 0101 in binary
unsigned int b = 3; // 0011 in binary
unsigned int and = a & b; // AND: 0001 (1)
unsigned int or = a | b; // OR: 0111 (7)
unsigned int xor = a ^ b; // XOR: 0110 (6)
unsigned int not = ~a; // NOT: every bit flipped (...11111010 for a 32-bit value)
unsigned int left = a << 1; // Left shift: 1010 (10)
unsigned int right = a >> 1;// Right shift: 0010 (2)
Assignment Operators
int x = 10;
x += 5; // x = x + 5; (15)
x -= 3; // x = x - 3; (12)
x *= 2; // x = x * 2; (24)
x /= 4; // x = x / 4; (6)
x %= 5; // x = x % 5; (1)
x &= 3; // x = x & 3;
x |= 2; // x = x | 2;
x ^= 1; // x = x ^ 1;
x <<= 1; // x = x << 1;
x >>= 1; // x = x >> 1;
Ternary Operator
int a = 10, b = 20;
int max = (a > b) ? a : b; // max = 20
// Equivalent to:
int max;
if (a > b) {
max = a;
} else {
max = b;
}
Control Flow
if-else Statements
int age = 18;
if (age >= 18) {
printf("Adult\n");
} else {
printf("Minor\n");
}
// if-else if-else
int score = 85;
if (score >= 90) {
printf("Grade: A\n");
} else if (score >= 80) {
printf("Grade: B\n");
} else if (score >= 70) {
printf("Grade: C\n");
} else {
printf("Grade: F\n");
}
// Nested if
int x = 10, y = 20;
if (x > 0) {
if (y > 0) {
printf("Both positive\n");
}
}
switch Statement
int day = 3;
switch (day) {
case 1:
printf("Monday\n");
break;
case 2:
printf("Tuesday\n");
break;
case 3:
printf("Wednesday\n");
break;
case 4:
printf("Thursday\n");
break;
case 5:
printf("Friday\n");
break;
case 6:
case 7:
printf("Weekend\n");
break;
default:
printf("Invalid day\n");
break;
}
// Switch with fall-through
char grade = 'B';
switch (grade) {
case 'A':
case 'B':
case 'C':
printf("Pass\n");
break;
case 'D':
case 'F':
printf("Fail\n");
break;
default:
printf("Invalid grade\n");
}
Loops
for Loop
// Basic for loop
for (int i = 0; i < 10; i++) {
printf("%d ", i);
}
// Nested for loop
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 3; j++) {
printf("(%d, %d) ", i, j);
}
printf("\n");
}
// Multiple expressions
for (int i = 0, j = 10; i < 10; i++, j--) {
printf("i=%d, j=%d\n", i, j);
}
// Infinite loop
for (;;) {
// Loop forever (use break to exit)
break;
}
while Loop
int count = 0;
while (count < 5) {
printf("%d ", count);
count++;
}
// Reading input until condition
int num;
printf("Enter positive numbers (0 to stop): ");
while (scanf("%d", &num) == 1 && num != 0) {
printf("You entered: %d\n", num);
}
// Infinite loop
while (1) {
// Loop forever
break;
}
do-while Loop
int num;
do {
printf("Enter a positive number: ");
scanf("%d", &num);
} while (num <= 0);
// Executes at least once
int x = 10;
do {
printf("x = %d\n", x);
x++;
} while (x < 5); // Condition false, but body executes once
Loop Control Statements
// break - exits the loop
for (int i = 0; i < 10; i++) {
if (i == 5) break;
printf("%d ", i); // Prints: 0 1 2 3 4
}
// continue - skips to next iteration
for (int i = 0; i < 10; i++) {
if (i % 2 == 0) continue;
printf("%d ", i); // Prints: 1 3 5 7 9
}
// goto - jumps to a label (use sparingly)
int i = 0;
start:
printf("%d ", i);
i++;
if (i < 5) goto start;
Functions
Function Declaration and Definition
// Function prototype (declaration)
int add(int a, int b);
void greet(void);
double calculate(int x, double y);
// Function definition
int add(int a, int b) {
return a + b;
}
void greet(void) {
printf("Hello!\n");
// No return statement needed for void
}
double calculate(int x, double y) {
return x * y;
}
Function Parameters
// Pass by value
void increment(int x) {
x++; // Only affects local copy
}
int main(void) {
int num = 5;
increment(num);
printf("%d\n", num); // Still 5
return 0;
}
// Pass by reference (using pointers)
void increment_ref(int *x) {
(*x)++; // Modifies original value
}
int main(void) {
int num = 5;
increment_ref(&num);
printf("%d\n", num); // Now 6
return 0;
}
Return Values
// Return single value
int square(int x) {
return x * x;
}
// Return multiple values via pointers
void divide(int a, int b, int *quotient, int *remainder) {
*quotient = a / b;
*remainder = a % b;
}
int main(void) {
int q, r;
divide(10, 3, &q, &r);
printf("10 / 3 = %d remainder %d\n", q, r);
return 0;
}
Variadic Functions
#include <stdarg.h>
// Function with variable number of arguments
int sum(int count, ...) {
va_list args;
va_start(args, count);
int total = 0;
for (int i = 0; i < count; i++) {
total += va_arg(args, int);
}
va_end(args);
return total;
}
int main(void) {
printf("Sum: %d\n", sum(3, 10, 20, 30)); // 60
printf("Sum: %d\n", sum(5, 1, 2, 3, 4, 5)); // 15
return 0;
}
Recursive Functions
// Factorial
int factorial(int n) {
if (n <= 1) return 1;
return n * factorial(n - 1);
}
// Fibonacci
int fibonacci(int n) {
if (n <= 1) return n;
return fibonacci(n - 1) + fibonacci(n - 2);
}
// Binary search (recursive)
int binary_search(int arr[], int left, int right, int target) {
if (left > right) return -1;
int mid = left + (right - left) / 2;
if (arr[mid] == target) return mid;
if (arr[mid] > target) return binary_search(arr, left, mid - 1, target);
return binary_search(arr, mid + 1, right, target);
}
Arrays
Array Declaration and Initialization
// Declaration
int numbers[5];
// Declaration with initialization
int primes[5] = {2, 3, 5, 7, 11};
// Partial initialization (rest are 0)
int values[10] = {1, 2, 3}; // {1, 2, 3, 0, 0, 0, 0, 0, 0, 0}
// Size inferred from initializer
int data[] = {10, 20, 30, 40}; // Size: 4
// Zero-initialize all elements
int zeros[100] = {0};
Accessing Array Elements
int arr[5] = {10, 20, 30, 40, 50};
// Access elements
int first = arr[0]; // 10
int third = arr[2]; // 30
// Modify elements
arr[1] = 25; // arr is now {10, 25, 30, 40, 50}
// Loop through array
for (int i = 0; i < 5; i++) {
printf("%d ", arr[i]);
}
Multi-dimensional Arrays
// 2D array
int matrix[3][4] = {
{1, 2, 3, 4},
{5, 6, 7, 8},
{9, 10, 11, 12}
};
// Access elements
int value = matrix[1][2]; // 7
// Loop through 2D array
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 4; j++) {
printf("%d ", matrix[i][j]);
}
printf("\n");
}
// 3D array
int cube[2][3][4];
Arrays and Pointers
int arr[5] = {10, 20, 30, 40, 50};
// Array name is a pointer to first element
int *ptr = arr; // Same as &arr[0]
// Pointer arithmetic
printf("%d\n", *ptr); // 10
printf("%d\n", *(ptr + 1)); // 20
printf("%d\n", *(ptr + 2)); // 30
// Equivalent notations: arr[2], *(arr + 2), *(ptr + 2), and ptr[2]
// all refer to the same element (value 30)
Passing Arrays to Functions
// Array passed as pointer
void print_array(int arr[], int size) {
for (int i = 0; i < size; i++) {
printf("%d ", arr[i]);
}
printf("\n");
}
// Equivalent declaration
void print_array(int *arr, int size) {
// Same as above
}
// 2D array
void print_matrix(int rows, int cols, int matrix[rows][cols]) {
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
printf("%d ", matrix[i][j]);
}
printf("\n");
}
}
Pointers
Pointer Basics
int x = 10;
int *ptr = &x; // ptr stores address of x
printf("Value of x: %d\n", x); // 10
printf("Address of x: %p\n", (void*)&x); // Memory address
printf("Value of ptr: %p\n", (void*)ptr);// Same as &x
printf("Value at ptr: %d\n", *ptr); // 10 (dereference)
// Modify through pointer
*ptr = 20;
printf("New value of x: %d\n", x); // 20
Pointer Arithmetic
int arr[5] = {10, 20, 30, 40, 50};
int *ptr = arr;
printf("%d\n", *ptr); // 10
printf("%d\n", *(ptr + 1)); // 20
printf("%d\n", *(ptr + 2)); // 30
ptr++; // Move to next element
printf("%d\n", *ptr); // 20
ptr += 2; // Move 2 elements forward
printf("%d\n", *ptr); // 40
Pointer to Pointer
int x = 10;
int *ptr1 = &x;
int **ptr2 = &ptr1;
printf("%d\n", **ptr2); // 10
// Modify through double pointer
**ptr2 = 20;
printf("%d\n", x); // 20
Null Pointers
int *ptr = NULL; // Initialize to NULL
// Always check before dereferencing
if (ptr != NULL) {
printf("%d\n", *ptr);
} else {
printf("Pointer is NULL\n");
}
Function Pointers
// Function pointer declaration
int (*func_ptr)(int, int);
int add(int a, int b) {
return a + b;
}
int multiply(int a, int b) {
return a * b;
}
int main(void) {
func_ptr = add;
printf("10 + 5 = %d\n", func_ptr(10, 5)); // 15
func_ptr = multiply;
printf("10 * 5 = %d\n", func_ptr(10, 5)); // 50
return 0;
}
// Array of function pointers
int (*operations[2])(int, int) = {add, multiply};
printf("Result: %d\n", operations[0](10, 5)); // 15
Structures (Structs)
Struct Declaration and Initialization
// Define struct
struct Point {
int x;
int y;
};
// Create struct variable
struct Point p1;
p1.x = 10;
p1.y = 20;
// Initialize during declaration
struct Point p2 = {30, 40};
// Designated initializers (C99)
struct Point p3 = {.x = 50, .y = 60};
Typedef with Structs
typedef struct {
char name[50];
int age;
float gpa;
} Student;
// Now can use Student instead of struct Student
Student s1 = {"Alice", 20, 3.8};
printf("Name: %s, Age: %d, GPA: %.2f\n", s1.name, s1.age, s1.gpa);
Nested Structures
typedef struct {
int day;
int month;
int year;
} Date;
typedef struct {
char name[50];
Date birthdate;
float salary;
} Employee;
Employee emp = {"John", {15, 8, 1990}, 50000.0};
printf("Name: %s\n", emp.name);
printf("Birthdate: %d/%d/%d\n", emp.birthdate.day,
emp.birthdate.month, emp.birthdate.year);
Pointers to Structures
typedef struct {
int x;
int y;
} Point;
Point p1 = {10, 20};
Point *ptr = &p1;
// Access members through pointer
printf("x: %d, y: %d\n", (*ptr).x, (*ptr).y);
// Arrow operator (shorthand)
printf("x: %d, y: %d\n", ptr->x, ptr->y);
Arrays of Structures
typedef struct {
char name[30];
int age;
} Person;
Person people[3] = {
{"Alice", 25},
{"Bob", 30},
{"Charlie", 35}
};
for (int i = 0; i < 3; i++) {
printf("%s is %d years old\n", people[i].name, people[i].age);
}
Unions and Enums
Unions
// Union: all members share same memory location
union Data {
int i;
float f;
char c;
};
union Data data;
data.i = 10;
printf("i: %d\n", data.i);
data.f = 3.14f; // Overwrites i (members share the same storage)
printf("f: %.2f\n", data.f);
printf("i: %d\n", data.i); // Garbage: i's bytes were overwritten by f
printf("Size of union: %zu\n", sizeof(union Data)); // Size of largest member
Enumerations
// Define enum
enum Day {
MONDAY, // 0
TUESDAY, // 1
WEDNESDAY, // 2
THURSDAY, // 3
FRIDAY, // 4
SATURDAY, // 5
SUNDAY // 6
};
enum Day today = WEDNESDAY;
// Custom values
enum Status {
SUCCESS = 0,
ERROR = -1,
PENDING = 1,
TIMEOUT = 2
};
// Typedef with enum
typedef enum {
RED,
GREEN,
BLUE
} Color;
Color favorite = BLUE;
File I/O
Opening and Closing Files
#include <stdio.h>
FILE *file = fopen("data.txt", "r"); // Open for reading
if (file == NULL) {
perror("Error opening file");
return 1;
}
// Use file...
fclose(file); // Always close when done
File Modes:
- "r"  - Read (file must exist)
- "w"  - Write (creates new or truncates existing)
- "a"  - Append (creates new or appends to existing)
- "r+" - Read and write (file must exist)
- "w+" - Read and write (creates new or truncates)
- "a+" - Read and append
Writing to Files
// fprintf - formatted output
FILE *file = fopen("output.txt", "w");
fprintf(file, "Hello, %s!\n", "World");
fprintf(file, "Number: %d\n", 42);
fclose(file);
// fputs - write string
FILE *file = fopen("output.txt", "w");
fputs("Line 1\n", file);
fputs("Line 2\n", file);
fclose(file);
// fwrite - binary write
int numbers[] = {1, 2, 3, 4, 5};
FILE *file = fopen("data.bin", "wb");
fwrite(numbers, sizeof(int), 5, file);
fclose(file);
Reading from Files
// fscanf - formatted input
FILE *file = fopen("input.txt", "r");
int num;
char str[50];
fscanf(file, "%d %s", &num, str);
fclose(file);
// fgets - read line
FILE *file = fopen("input.txt", "r");
char line[100];
while (fgets(line, sizeof(line), file) != NULL) {
printf("%s", line);
}
fclose(file);
// fread - binary read
int numbers[5];
FILE *file = fopen("data.bin", "rb");
fread(numbers, sizeof(int), 5, file);
fclose(file);
// fgetc - read character
FILE *file = fopen("input.txt", "r");
int ch;
while ((ch = fgetc(file)) != EOF) {
putchar(ch);
}
fclose(file);
File Position Functions
FILE *file = fopen("data.txt", "r");
// ftell - get current position
long pos = ftell(file);
// fseek - set position
fseek(file, 0, SEEK_SET); // Beginning of file
fseek(file, 0, SEEK_END); // End of file
fseek(file, 10, SEEK_CUR); // 10 bytes from current position
// rewind - reset to beginning
rewind(file);
fclose(file);
File Error Checking
FILE *file = fopen("data.txt", "r");
if (file == NULL) {
perror("fopen");
return 1;
}
// Check for errors
if (ferror(file)) {
fprintf(stderr, "Error reading file\n");
}
// Check for end of file
if (feof(file)) {
printf("End of file reached\n");
}
fclose(file);
Preprocessor Directives
#include Directive
#include <stdio.h> // System header
#include <stdlib.h>
#include <string.h>
#include "myheader.h" // User-defined header
#define Directive
// Constants
#define PI 3.14159
#define MAX_SIZE 1000
#define BUFFER_LEN 256
// Macros
#define SQUARE(x) ((x) * (x))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
// Multi-line macro
#define SWAP(a, b, type) do { \
type temp = a; \
a = b; \
b = temp; \
} while(0)
// Usage
int result = SQUARE(5); // 25
int max_val = MAX(10, 20); // 20
Conditional Compilation
#define DEBUG 1
#ifdef DEBUG
printf("Debug mode enabled\n");
#endif
#ifndef RELEASE
printf("Not in release mode\n");
#endif
#if DEBUG == 1
printf("Debug level 1\n");
#elif DEBUG == 2
printf("Debug level 2\n");
#else
printf("Debug disabled\n");
#endif
// Prevent multiple inclusion
#ifndef MYHEADER_H
#define MYHEADER_H
// Header contents...
#endif // MYHEADER_H
Predefined Macros
printf("File: %s\n", __FILE__); // Current filename
printf("Line: %d\n", __LINE__); // Current line number
printf("Date: %s\n", __DATE__); // Compilation date
printf("Time: %s\n", __TIME__); // Compilation time
printf("Function: %s\n", __func__); // Current function name (C99)
#undef and #pragma
// Undefine a macro
#define TEMP 100
#undef TEMP
// Compiler-specific directives
#pragma once // Alternative to include guards (non-standard)
#pragma pack(1) // Structure packing
Common Patterns
Error Handling Pattern
int process_file(const char *filename) {
FILE *file = fopen(filename, "r");
if (file == NULL) {
perror("fopen");
return -1;
}
char *buffer = malloc(1024);
if (buffer == NULL) {
fclose(file);
perror("malloc");
return -1;
}
// Process file...
// Cleanup
free(buffer);
fclose(file);
return 0;
}
Generic Swap Function
void swap(void *a, void *b, size_t size) {
unsigned char *p = a;
unsigned char *q = b;
unsigned char temp;
for (size_t i = 0; i < size; i++) {
temp = p[i];
p[i] = q[i];
q[i] = temp;
}
}
// Usage
int x = 10, y = 20;
swap(&x, &y, sizeof(int));
printf("x=%d, y=%d\n", x, y); // x=20, y=10
Linked List Implementation
typedef struct Node {
int data;
struct Node *next;
} Node;
// Insert at beginning
Node* insert_front(Node *head, int data) {
Node *new_node = malloc(sizeof(Node));
if (new_node == NULL) return head;
new_node->data = data;
new_node->next = head;
return new_node;
}
// Print list
void print_list(Node *head) {
Node *current = head;
while (current != NULL) {
printf("%d -> ", current->data);
current = current->next;
}
printf("NULL\n");
}
// Free list
void free_list(Node *head) {
Node *current = head;
while (current != NULL) {
Node *temp = current;
current = current->next;
free(temp);
}
}
Command Line Arguments
int main(int argc, char *argv[]) {
printf("Program name: %s\n", argv[0]);
printf("Number of arguments: %d\n", argc - 1);
for (int i = 1; i < argc; i++) {
printf("Argument %d: %s\n", i, argv[i]);
}
return 0;
}
// Run: ./program arg1 arg2 arg3
// Output:
// Program name: ./program
// Number of arguments: 3
// Argument 1: arg1
// Argument 2: arg2
// Argument 3: arg3
Best Practices
Code Organization
// Use meaningful names
int calculate_average(int *scores, int count); // Good
int calc(int *a, int n); // Avoid
// Use constants instead of magic numbers
#define MAX_STUDENTS 100
int students[MAX_STUDENTS]; // Good
int students[100]; // Avoid
// Group related code
typedef struct {
char name[50];
int age;
} Person;
Person create_person(const char *name, int age);
void print_person(const Person *p);
void free_person(Person *p);
Memory Management
// Always check malloc return value
int *ptr = malloc(sizeof(int) * 100);
if (ptr == NULL) {
fprintf(stderr, "Memory allocation failed\n");
return -1;
}
// Always free allocated memory
free(ptr);
ptr = NULL; // Prevent dangling pointer
// Avoid memory leaks
void bad_function() {
int *data = malloc(sizeof(int) * 100);
if (some_error) {
return; // LEAK! Forgot to free
}
free(data);
}
void good_function() {
int *data = malloc(sizeof(int) * 100);
if (data == NULL) return;
if (some_error) {
free(data); // Clean up before return
return;
}
free(data);
}
Buffer Safety
// Use strncpy instead of strcpy
char dest[20];
strncpy(dest, source, sizeof(dest) - 1);
dest[sizeof(dest) - 1] = '\0'; // Ensure null termination
// Use snprintf instead of sprintf
char buffer[50];
snprintf(buffer, sizeof(buffer), "Value: %d", value);
// Check array bounds
for (int i = 0; i < array_size; i++) {
// Safe access
}
Function Design
// Use const for read-only parameters
int calculate_sum(const int *arr, int size);
// Return error codes
int read_file(const char *filename, char **buffer) {
if (filename == NULL || buffer == NULL) {
return -1; // Invalid parameters
}
FILE *file = fopen(filename, "r");
if (file == NULL) {
return -2; // File open error
}
// Success
return 0;
}
// Use header guards
// myheader.h
#ifndef MYHEADER_H
#define MYHEADER_H
// Declarations...
#endif
Compilation Flags
# Enable warnings
gcc -Wall -Wextra -Werror program.c -o program
# Debug symbols
gcc -g program.c -o program
# Optimization
gcc -O2 program.c -o program
# C standard
gcc -std=c11 program.c -o program
# Combine flags
gcc -Wall -Wextra -O2 -std=c11 program.c -o program
Difference Between Different Const Pointers
In C programming, pointers can be declared with the const qualifier in different ways, leading to different types of constant pointers. Understanding these differences is crucial for writing correct and efficient code.
1. Pointer to a Constant Variable: the value being pointed to cannot be changed through the pointer, but the pointer itself can be repointed to another variable.

   int a = 10;
   int b = 20;
   const int *ptr = &a;
   *ptr = 30;   // Invalid: cannot change the value of 'a' through ptr
   ptr = &b;    // Valid: can repoint ptr to 'b'

2. Constant Pointer to a Variable: the pointer itself cannot be repointed, but the value being pointed to can be changed.

   int a = 10;
   int b = 20;
   int *const ptr = &a;
   ptr = &b;    // Invalid: cannot repoint ptr to 'b'
   *ptr = 30;   // Valid: can change the value of 'a' through ptr

3. Constant Pointer to a Constant Variable: neither the pointer can be repointed nor the pointed-to value changed.

   int a = 10;
   int b = 20;
   const int *const ptr = &a;
   ptr = &b;    // Invalid: cannot repoint ptr
   *ptr = 30;   // Invalid: cannot change the value through ptr
These different types of constant pointers provide various levels of protection and control over the data and pointers in your program, helping to prevent unintended modifications and ensuring code reliability.
Dynamic Memory Allocation
Dynamic memory allocation allows you to allocate memory at runtime instead of compile time. This is essential for creating data structures of variable size.
Memory Layout
Stack (grows down) | Local variables, function parameters
|
V
=====================|====================== <- Stack limit
^
|
| Free memory
|
V
=====================|====================== <- Heap limit
Heap (grows up) | malloc, calloc, realloc allocations
^
malloc() - Memory Allocation
Allocates memory and returns a void pointer:
#include <stdlib.h>
// Allocate memory for single integer
int *ptr = (int *)malloc(sizeof(int));
if (ptr == NULL) {
printf("Memory allocation failed\n");
return 1;
}
*ptr = 42;
printf("Value: %d\n", *ptr);
free(ptr);
ptr = NULL; // Good practice: set to NULL after free
// Allocate memory for array
int *arr = (int *)malloc(10 * sizeof(int));
if (arr == NULL) {
printf("Memory allocation failed\n");
return 1;
}
arr[0] = 100;
arr[9] = 999;
free(arr);
arr = NULL;
calloc() - Contiguous Memory Allocation
Allocates memory and initializes all bytes to zero:
#include <stdlib.h>
// calloc(number_of_elements, size_of_each_element)
int *arr = (int *)calloc(10, sizeof(int)); // 10 integers, all initialized to 0
if (arr == NULL) {
printf("Memory allocation failed\n");
return 1;
}
for (int i = 0; i < 10; i++) {
printf("%d ", arr[i]); // Prints: 0 0 0 0 0 0 0 0 0 0
}
free(arr);
arr = NULL;
realloc() - Resize Memory
Resizes previously allocated memory block:
#include <stdlib.h>
int *arr = (int *)malloc(5 * sizeof(int));
for (int i = 0; i < 5; i++) arr[i] = i;
// Resize to 10 integers
int *new_arr = (int *)realloc(arr, 10 * sizeof(int));
if (new_arr == NULL) {
printf("Reallocation failed\n");
free(arr); // Original block still exists
return 1;
}
arr = new_arr;
for (int i = 5; i < 10; i++) arr[i] = i;
free(arr);
arr = NULL;
free() - Deallocate Memory
Deallocates previously allocated memory:
#include <stdlib.h>
int *ptr = (int *)malloc(sizeof(int));
*ptr = 42;
// When done, free the memory
free(ptr);
// IMPORTANT: Set to NULL to avoid dangling pointer
ptr = NULL;
Memory Allocation Pattern (Safe)
#include <stdlib.h>
#include <stdio.h>
int main(void) {
// 1. Declare pointer
int *ptr;
// 2. Allocate memory with error checking
ptr = (int *)malloc(sizeof(int));
if (ptr == NULL) {
fprintf(stderr, "Memory allocation failed\n");
return 1;
}
// 3. Use the memory
*ptr = 100;
printf("Value: %d\n", *ptr);
// 4. Free the memory
free(ptr);
// 5. Set to NULL (avoid dangling pointer)
ptr = NULL;
return 0;
}
Dynamic Array Implementation
#include <stdlib.h>
#include <stdio.h>
typedef struct {
int *data;
int size;
int capacity;
} DynamicArray;
// Create dynamic array
DynamicArray* array_create(int initial_capacity) {
DynamicArray *arr = (DynamicArray *)malloc(sizeof(DynamicArray));
if (arr == NULL) return NULL;
arr->data = (int *)malloc(initial_capacity * sizeof(int));
if (arr->data == NULL) {
free(arr);
return NULL;
}
arr->size = 0;
arr->capacity = initial_capacity;
return arr;
}
// Add element to array
int array_push(DynamicArray *arr, int value) {
if (arr->size == arr->capacity) {
// Resize: double the capacity
int new_capacity = arr->capacity * 2;
int *new_data = (int *)realloc(arr->data, new_capacity * sizeof(int));
if (new_data == NULL) return -1;
arr->data = new_data;
arr->capacity = new_capacity;
}
arr->data[arr->size++] = value;
return 0;
}
// Get element from array
int array_get(DynamicArray *arr, int index) {
if (index < 0 || index >= arr->size) {
fprintf(stderr, "Index out of bounds\n");
return -1;
}
return arr->data[index];
}
// Free array
void array_free(DynamicArray *arr) {
if (arr == NULL) return;
free(arr->data);
free(arr);
}
// Usage
int main(void) {
DynamicArray *arr = array_create(10);
if (arr == NULL) {
fprintf(stderr, "Failed to create array\n");
return 1;
}
for (int i = 0; i < 20; i++) {
array_push(arr, i * 10);
}
for (int i = 0; i < arr->size; i++) {
printf("%d ", array_get(arr, i));
}
array_free(arr);
return 0;
}
Common Memory Errors
1. Memory Leak (forgot to free)
void memory_leak(void) {
int *ptr = (int *)malloc(sizeof(int));
*ptr = 42;
// Missing: free(ptr);
// Memory is lost when function exits
}
2. Double Free
int *ptr = (int *)malloc(sizeof(int));
free(ptr);
free(ptr); // ERROR: Undefined behavior!
3. Use After Free (Dangling Pointer)
int *ptr = (int *)malloc(sizeof(int));
*ptr = 42;
free(ptr);
printf("%d\n", *ptr); // ERROR: ptr points to freed memory!
ptr = NULL; // Should do this after free
4. Buffer Overflow
char *str = (char *)malloc(5);
strcpy(str, "Hello World"); // ERROR: Buffer overflow!
// "Hello World" needs 12 bytes, only allocated 5
free(str);
5. Null Pointer Dereference
int *ptr = (int *)malloc(sizeof(int));
// If the allocation failed, ptr is NULL
*ptr = 42; // ERROR: dereferencing NULL when malloc failed!
// Always check the result first: if (ptr == NULL) { /* handle the error */ }
Best Practices for Dynamic Memory
// 1. Always check if malloc/calloc/realloc succeeded
int *ptr = (int *)malloc(sizeof(int));
if (ptr == NULL) {
// Handle error
return -1;
}
// 2. Use sizeof for type safety
int *arr = (int *)malloc(100 * sizeof(int)); // Good
int *arr = (int *)malloc(100 * 4); // Avoid - hardcoded size
// 3. Free in reverse order of allocation
void *p1 = malloc(100);
void *p2 = malloc(200);
void *p3 = malloc(300);
free(p3);
free(p2);
free(p1);
// 4. Set pointer to NULL after free
free(ptr);
ptr = NULL;
// 5. Avoid memory leaks - create cleanup paths
FILE *file = fopen("data.txt", "r");
int *data = (int *)malloc(1000 * sizeof(int));
if (file == NULL) {
free(data); // Clean up before returning
return -1;
}
// Process file
if (some_error) {
fclose(file);
free(data); // Clean up before returning
return -1;
}
fclose(file);
free(data);
return 0;
// 6. Use wrapper functions for consistency
void* safe_malloc(size_t size) {
void *ptr = malloc(size);
if (ptr == NULL) {
fprintf(stderr, "malloc failed: requested %zu bytes\n", size);
exit(1); // Or handle error differently
}
return ptr;
}
int *arr = (int *)safe_malloc(100 * sizeof(int));
Memory Leak Detection Tools
# Valgrind - memory error detector
valgrind --leak-check=full --show-leak-kinds=all ./program
# AddressSanitizer (GCC/Clang)
gcc -fsanitize=address -g program.c -o program
./program
# Dr. Memory (Windows)
drmemory -leaks_only -- program.exe
Comparison of Allocation Functions
| Function | Initialization | Returns NULL on Fail | Use Case |
|---|---|---|---|
| malloc | No (garbage values) | Yes | When you'll initialize manually |
| calloc | Yes (all zeros) | Yes | When you need zeroed memory |
| realloc | Preserves existing | Yes | When resizing allocations |
Memory Allocation Time Complexity
| Operation | Time Complexity | Space Complexity |
|---|---|---|
| malloc/calloc | O(1) amortized | O(1) |
| free | O(1) amortized | O(1) |
| realloc | O(n) | O(n) |
Commonly Used String Library Functions
The C standard library provides a set of functions for manipulating strings. Here are some commonly used string functions:
- strlen - Calculate the length of a string:
  #include <string.h>
  size_t length = strlen("example");

- strcpy - Copy a string:
  #include <string.h>
  char dest[20];
  strcpy(dest, "source");

- strncpy - Copy a specified number of characters from a string:
  #include <string.h>
  char dest[20];
  strncpy(dest, "source", 5);

- strcat - Concatenate two strings:
  #include <string.h>
  char dest[20] = "Hello, ";
  strcat(dest, "World!");

- strncat - Concatenate a specified number of characters from one string to another:
  #include <string.h>
  char dest[20] = "Hello, ";
  strncat(dest, "World!", 3);

- strcmp - Compare two strings:
  #include <string.h>
  int result = strcmp("string1", "string2");

- strncmp - Compare a specified number of characters from two strings:
  #include <string.h>
  int result = strncmp("string1", "string2", 5);

- strchr - Find the first occurrence of a character in a string:
  #include <string.h>
  char *ptr = strchr("example", 'a');

- strrchr - Find the last occurrence of a character in a string:
  #include <string.h>
  char *ptr = strrchr("example", 'e');

- strstr - Find the first occurrence of a substring in a string:
  #include <string.h>
  char *ptr = strstr("example", "amp");
These functions cover a variety of common use cases for string manipulation in C, making them essential tools for C programmers.
Variants of printf and scanf
The printf and scanf functions are commonly used for input and output in C. There are several variants of these functions that provide additional functionality.
printf Variants
- printf - Print formatted output to the standard output:
  #include <stdio.h>
  printf("Hello, %s!\n", "World");

- fprintf - Print formatted output to a file:
  #include <stdio.h>
  FILE *file = fopen("output.txt", "w");
  fprintf(file, "Hello, %s!\n", "World");
  fclose(file);

- sprintf - Print formatted output to a string:
  #include <stdio.h>
  char buffer[50];
  sprintf(buffer, "Hello, %s!", "World");

- snprintf - Print formatted output to a string with a limit on the number of characters:
  #include <stdio.h>
  char buffer[50];
  snprintf(buffer, sizeof(buffer), "Hello, %s!", "World");

- vprintf - Print formatted output using a va_list:
  #include <stdio.h>
  #include <stdarg.h>
  void my_vprintf(const char *format, ...) {
      va_list args;
      va_start(args, format);
      vprintf(format, args);
      va_end(args);
  }

- vfprintf - Print formatted output to a file using a va_list:
  #include <stdio.h>
  #include <stdarg.h>
  void my_vfprintf(FILE *file, const char *format, ...) {
      va_list args;
      va_start(args, format);
      vfprintf(file, format, args);
      va_end(args);
  }

- vsprintf - Print formatted output to a string using a va_list:
  #include <stdio.h>
  #include <stdarg.h>
  void my_vsprintf(char *buffer, const char *format, ...) {
      va_list args;
      va_start(args, format);
      vsprintf(buffer, format, args);
      va_end(args);
  }

- vsnprintf - Print formatted output to a string with a limit on the number of characters using a va_list:
  #include <stdio.h>
  #include <stdarg.h>
  void my_vsnprintf(char *buffer, size_t size, const char *format, ...) {
      va_list args;
      va_start(args, format);
      vsnprintf(buffer, size, format, args);
      va_end(args);
  }
scanf Variants
- scanf - Read formatted input from the standard input:
  #include <stdio.h>
  int value;
  scanf("%d", &value);

- fscanf - Read formatted input from a file:
  #include <stdio.h>
  FILE *file = fopen("input.txt", "r");
  int value;
  fscanf(file, "%d", &value);
  fclose(file);

- sscanf - Read formatted input from a string:
  #include <stdio.h>
  const char *str = "123";
  int value;
  sscanf(str, "%d", &value);

- vscanf - Read formatted input using a va_list:
  #include <stdio.h>
  #include <stdarg.h>
  void my_vscanf(const char *format, ...) {
      va_list args;
      va_start(args, format);
      vscanf(format, args);
      va_end(args);
  }

- vfscanf - Read formatted input from a file using a va_list:
  #include <stdio.h>
  #include <stdarg.h>
  void my_vfscanf(FILE *file, const char *format, ...) {
      va_list args;
      va_start(args, format);
      vfscanf(file, format, args);
      va_end(args);
  }

- vsscanf - Read formatted input from a string using a va_list:
  #include <stdio.h>
  #include <stdarg.h>
  void my_vsscanf(const char *str, const char *format, ...) {
      va_list args;
      va_start(args, format);
      vsscanf(str, format, args);
      va_end(args);
  }
These variants of printf and scanf provide flexibility for different input and output scenarios in C programming.
C++
Overview
C++ is an extension of C that adds object-oriented features and other enhancements.
Key Features
- Object-oriented programming
- Generic programming support
- Standard Template Library (STL)
- Low-level memory manipulation
- High performance
Object Instantiation Patterns
C++ provides multiple ways to create and initialize objects, each with different characteristics regarding memory management, lifetime, and performance.
1. Stack Allocation (Automatic Storage)
Objects created on the stack have automatic lifetime - they're destroyed when they go out of scope.
class MyClass {
public:
int value;
MyClass(int v) : value(v) {
std::cout << "Constructor called: " << value << std::endl;
}
~MyClass() {
std::cout << "Destructor called: " << value << std::endl;
}
};
void example() {
MyClass obj1(10); // Stack allocation
MyClass obj2 = MyClass(20); // Also stack allocation
MyClass obj3{30}; // C++11 uniform initialization
// All objects destroyed automatically when function exits
}
Advantages:
- Fast allocation/deallocation
- Automatic cleanup (RAII)
- No memory leaks
Disadvantages:
- Limited stack size
- Objects can't outlive their scope
2. Heap Allocation with new/delete
Objects created on the heap persist until explicitly deleted.
// Single object
MyClass* ptr1 = new MyClass(100); // Allocate on heap
// Use ptr1...
delete ptr1; // Must manually delete
ptr1 = nullptr; // Good practice
// Array of objects (note: this requires MyClass to have a default constructor)
MyClass* arr = new MyClass[5]; // Default constructor runs for each element
// Use arr...
delete[] arr; // Must use delete[] for arrays
arr = nullptr;
// With initialization (C++11)
MyClass* ptr2 = new MyClass{200};
delete ptr2;
Advantages:
- Objects can outlive their scope
- Larger available memory
- Dynamic sizing
Disadvantages:
- Manual memory management
- Risk of memory leaks
- Slower than stack allocation
3. Smart Pointers (Modern C++)
Smart pointers provide automatic memory management for heap-allocated objects.
#include <memory>
// std::unique_ptr - exclusive ownership
{
std::unique_ptr<MyClass> ptr1 = std::make_unique<MyClass>(10);
// Automatically deleted when ptr1 goes out of scope
// Cannot be copied, only moved
auto ptr2 = std::make_unique<MyClass>(20); // Using auto
std::unique_ptr<MyClass> ptr3 = std::move(ptr2); // Transfer ownership
// ptr2 is now nullptr
}
// std::shared_ptr - shared ownership
{
std::shared_ptr<MyClass> ptr1 = std::make_shared<MyClass>(30);
{
std::shared_ptr<MyClass> ptr2 = ptr1; // Both own the object
std::cout << "Reference count: " << ptr1.use_count() << std::endl; // 2
} // ptr2 destroyed, object still exists
std::cout << "Reference count: " << ptr1.use_count() << std::endl; // 1
} // Object deleted when last shared_ptr is destroyed
// Array form of make_unique (C++14)
auto arr = std::make_unique<MyClass[]>(5);
Advantages:
- Automatic memory management
- Exception-safe
- Clear ownership semantics
Disadvantages:
- Slight overhead (especially shared_ptr)
- Reference counting overhead
4. Initialization Patterns
C++ offers various initialization syntaxes with different behaviors.
class Point {
public:
int x, y;
Point() : x(0), y(0) {}
Point(int x, int y) : x(x), y(y) {}
};
// Default initialization
Point p1; // Calls default constructor: Point()
// Direct initialization
Point p2(10, 20); // Calls Point(int, int)
// Copy initialization
Point p3 = Point(30, 40); // May involve copy/move
// List initialization (Uniform initialization - C++11)
Point p4{50, 60}; // Direct list initialization
Point p5 = {70, 80}; // Copy list initialization
auto p6 = Point{90, 100}; // With auto
// Value initialization
Point p7{}; // Zero-initializes: x=0, y=0
Point* p8 = new Point(); // Value initialization on heap
// Aggregate initialization (for POD types)
struct Data {
int a;
double b;
char c;
};
Data d1 = {1, 2.5, 'x'}; // C-style
Data d2{1, 2.5, 'x'}; // C++11 style
Data d3{.a=1, .b=2.5}; // C++20 designated initializers
5. Constructor Patterns
Different ways to call constructors for initialization.
class Resource {
private:
int* data;
size_t size;
public:
// Default constructor
Resource() : data(nullptr), size(0) {
std::cout << "Default constructor" << std::endl;
}
// Parameterized constructor
Resource(size_t sz) : data(new int[sz]), size(sz) {
std::cout << "Parameterized constructor" << std::endl;
}
// Copy constructor
Resource(const Resource& other) : size(other.size) {
data = new int[size];
std::copy(other.data, other.data + size, data);
std::cout << "Copy constructor" << std::endl;
}
// Move constructor (C++11)
Resource(Resource&& other) noexcept : data(other.data), size(other.size) {
other.data = nullptr;
other.size = 0;
std::cout << "Move constructor" << std::endl;
}
// Destructor
~Resource() {
delete[] data;
std::cout << "Destructor" << std::endl;
}
};
// Usage examples
Resource r1; // Default constructor
Resource r2(100); // Parameterized constructor
Resource r3 = r2; // Copy constructor
Resource r4 = std::move(r2); // Move constructor
Resource r5(std::move(r3)); // Move constructor (explicit)
6. Factory Pattern
Using factory functions for object creation.
class Shape {
public:
virtual void draw() = 0;
virtual ~Shape() = default;
};
class Circle : public Shape {
double radius;
public:
Circle(double r) : radius(r) {}
void draw() override { std::cout << "Drawing circle" << std::endl; }
};
class Rectangle : public Shape {
double width, height;
public:
Rectangle(double w, double h) : width(w), height(h) {}
void draw() override { std::cout << "Drawing rectangle" << std::endl; }
};
// Factory function
std::unique_ptr<Shape> createShape(const std::string& type) {
if (type == "circle") {
return std::make_unique<Circle>(5.0);
} else if (type == "rectangle") {
return std::make_unique<Rectangle>(4.0, 6.0);
}
return nullptr;
}
// Usage
auto shape = createShape("circle");
if (shape) {
shape->draw();
}
7. Placement New
Constructing objects at a specific memory location.
#include <new>
// Pre-allocated buffer
alignas(MyClass) char buffer[sizeof(MyClass)];
// Construct object in buffer
MyClass* obj = new (buffer) MyClass(42);
// Use object
obj->value = 100;
// Must manually call destructor
obj->~MyClass();
// Common use case: memory pools (simplified bump allocator;
// a real pool must also respect each type's alignment)
class MemoryPool {
char buffer[1024];
size_t offset = 0;
public:
template<typename T, typename... Args>
T* construct(Args&&... args) {
void* ptr = buffer + offset; // hand out the next free slot
offset += sizeof(T);
return new (ptr) T(std::forward<Args>(args)...);
}
};
8. Array Initialization Patterns
Different ways to create and initialize arrays of objects.
// Stack arrays
MyClass arr1[3]; // Default constructor for each
MyClass arr2[3] = {MyClass(1), MyClass(2), MyClass(3)}; // Specific initialization
MyClass arr3[] = {MyClass(10), MyClass(20)}; // Size inferred
// Uniform initialization (C++11)
MyClass arr4[3] = {{1}, {2}, {3}};
MyClass arr5[3]{{1}, {2}, {3}};
// Heap arrays
MyClass* heap_arr1 = new MyClass[5]; // Default constructor
delete[] heap_arr1;
// std::array (C++11)
#include <array>
std::array<MyClass, 3> arr6 = {MyClass(1), MyClass(2), MyClass(3)};
std::array<MyClass, 3> arr7{MyClass(1), MyClass(2), MyClass(3)};
// std::vector (dynamic array)
#include <vector>
std::vector<MyClass> vec1; // Empty vector
std::vector<MyClass> vec2(5); // 5 default-constructed objects
std::vector<MyClass> vec3(5, MyClass(42)); // 5 copies of MyClass(42)
std::vector<MyClass> vec4{MyClass(1), MyClass(2), MyClass(3)}; // Initializer list
9. Emplace Construction
Constructing objects in-place within containers (C++11).
#include <vector>
#include <map>
std::vector<MyClass> vec;
// push_back creates temporary and moves/copies it
vec.push_back(MyClass(10));
// emplace_back constructs directly in the vector (more efficient)
vec.emplace_back(20); // Constructs MyClass(20) in-place
// Similarly for maps
std::map<int, MyClass> myMap;
myMap.emplace(1, MyClass(100)); // Creates pair in-place
myMap.try_emplace(2, 200); // C++17: does nothing if the key already exists
// emplace with multiple arguments
struct Person {
std::string name;
int age;
Person(std::string n, int a) : name(n), age(a) {}
};
std::vector<Person> people;
people.emplace_back("Alice", 30); // Constructs Person directly in vector
10. RAII Pattern (Resource Acquisition Is Initialization)
Tying resource lifetime to object lifetime.
class FileHandler {
FILE* file;
public:
// Resource acquired in constructor
FileHandler(const char* filename, const char* mode) {
file = fopen(filename, mode);
if (!file) throw std::runtime_error("Failed to open file");
}
// Resource released in destructor
~FileHandler() {
if (file) {
fclose(file);
}
}
// Prevent copying
FileHandler(const FileHandler&) = delete;
FileHandler& operator=(const FileHandler&) = delete;
// Allow moving
FileHandler(FileHandler&& other) noexcept : file(other.file) {
other.file = nullptr;
}
FILE* get() { return file; }
};
// Usage - no need to manually close file
void processFile() {
FileHandler handler("data.txt", "r");
// Use handler.get()...
// File automatically closed when handler goes out of scope
}
11. Copy Elision and RVO (Return Value Optimization)
The compiler can optimize away unnecessary copies.
MyClass createObject() {
MyClass obj(100);
return obj; // RVO: object constructed directly in caller's space
}
MyClass obj1 = createObject(); // No copy/move, direct construction (C++17 guaranteed)
// Named Return Value Optimization (NRVO)
MyClass createNamed(int value) {
MyClass result(value);
// ... operations on result
return result; // May be optimized (not guaranteed)
}
Best Practices for Object Instantiation
- Prefer stack allocation when possible - it's fastest and safest
- Use smart pointers instead of raw new/delete for heap allocation
- Use std::make_unique and std::make_shared for creating smart pointers
- Use uniform initialization {} to avoid the most vexing parse and narrowing conversions
- Use emplace methods in containers for in-place construction
- Follow RAII principles for resource management
- Prefer std::vector and std::array over raw arrays
- Avoid naked new - use smart pointers or containers
// Good practices example
void goodPractices() {
// Stack allocation when lifetime is scoped
MyClass local(42);
// Smart pointers for heap allocation
auto ptr = std::make_unique<MyClass>(100);
// Uniform initialization
MyClass obj{50};
// Containers for collections
std::vector<MyClass> vec;
vec.emplace_back(10);
vec.emplace_back(20);
// RAII for resources
std::ifstream file("data.txt");
// File automatically closed
}
C++ Strings and Their Methods
In C++, the std::string class provides a powerful and flexible way to handle strings. It offers a variety of methods for string manipulation, making it easier to perform common operations without dealing with low-level character arrays. Below are some of the most commonly used std::string methods in detail:
1. Constructors
std::string offers multiple constructors to initialize strings in different ways.
#include <string>
// Default constructor
std::string str1;
// Constructor with a C-string
std::string str2("Hello, World!");
// Constructor with a specific number of repeated characters
std::string str3(5, 'a'); // "aaaaa"
// Copy constructor
std::string str4(str2);
// Substring constructor
std::string str5(str2, 7, 5); // "World"
2. Size and Capacity
- size() / length(): Returns the number of characters in the string.
- capacity(): Returns the size of the storage space currently allocated for the string.
std::string str = "Example";
size_t len = str.size(); // 7
size_t cap = str.capacity(); // Implementation-defined
3. Accessing Characters
- operator[]: Accesses character at a specific index.
- at(): Accesses character at a specific index with bounds checking.
- front() / back(): Access the first and last characters.
std::string str = "Hello";
char ch = str[1]; // 'e'
char ch_at = str.at(2); // 'l'
char first = str.front(); // 'H'
char last = str.back(); // 'o'
4. Modifiers
- append(): Adds characters to the end of the string.
- clear(): Removes all characters from the string.
- insert(): Inserts characters at a specified position.
- erase(): Removes characters from a specified position.
- replace(): Replaces part of the string with another string.
std::string str = "Hello";
str.append(", World!"); // "Hello, World!"
str.insert(5, " C++"); // "Hello C++, World!"
str.erase(5, 6); // "HelloWorld!"
str.replace(5, 5, " C++"); // "Hello C++!"
str.clear(); // ""
5. Substring and Extracting
- substr(): Returns a substring starting from a specified position.
std::string str = "Hello, World!";
std::string sub = str.substr(7, 5); // "World"
6. Finding Characters and Substrings
- find(): Searches for a substring or character and returns the position.
- rfind(): Searches for a substring or character from the end.
std::string str = "Hello, World!";
size_t pos = str.find("World"); // 7
size_t rpos = str.rfind('o'); // 8
7. Comparison
- compare(): Compares two strings.
std::string str1 = "apple";
std::string str2 = "banana";
int result = str1.compare(str2);
// result < 0 since "apple" < "banana"
8. Conversion to C-string
- c_str(): Returns a C-style null-terminated string.
std::string str = "Hello";
const char* cstr = str.c_str();
9. Iterators
std::string supports iterators to traverse the string.
std::string str = "Hello";
for (std::string::iterator it = str.begin(); it != str.end(); ++it) {
std::cout << *it << ' ';
}
// Output: H e l l o
10. Appending Characters
Note: unlike containers such as std::vector, std::string does not provide emplace or emplace_back. Use push_back() or operator+= instead.
- push_back(): Appends a single character to the end of the string.
- operator+=: Appends a character or another string.
std::string str = "Hello";
str.push_back('!'); // "Hello!"
str += '?'; // "Hello!?"
11. Swap
- swap(): Swaps the contents of two strings.
std::string str1 = "Hello";
std::string str2 = "World";
str1.swap(str2);
// str1 is now "World", str2 is now "Hello"
12. Transform
You can apply transformations to each character using algorithms.
#include <algorithm>
std::string str = "Hello";
std::transform(str.begin(), str.end(), str.begin(), ::toupper); // "HELLO"
13. Other Useful Methods
- empty(): Checks if the string is empty.
- find_first_of() / find_last_of(): Finds the first/last occurrence of any character from a set.
- find_first_not_of() / find_last_not_of(): Finds the first/last character not in a set.
std::string str = "Hello";
bool isEmpty = str.empty(); // false
size_t pos = str.find_first_of('e'); // 1
size_t not_pos = str.find_first_not_of('H'); // 1
Example Usage
#include <iostream>
#include <string>
int main() {
std::string greeting = "Hello";
greeting += ", World!"; // Using operator +=
std::cout << greeting << std::endl; // Output: Hello, World!
// Find and replace
size_t pos = greeting.find("World");
if (pos != std::string::npos) {
greeting.replace(pos, 5, "C++");
}
std::cout << greeting << std::endl; // Output: Hello, C++!
return 0;
}
Understanding and utilizing these std::string methods can greatly enhance your ability to manipulate and manage text in C++ applications effectively.
C++ Vectors and Their Methods
In C++, the std::vector class template provides a dynamic array that can resize itself automatically when elements are added or removed. It offers numerous methods to manipulate the data efficiently. Below are detailed explanations and examples of various std::vector methods:
1. Constructors
std::vector offers multiple constructors to initialize vectors in different ways.
#include <vector>
// Default constructor
std::vector<int> vec1;
// Constructor with a specific size
std::vector<int> vec2(5); // {0, 0, 0, 0, 0}
// Constructor with a specific size and initial value
std::vector<int> vec3(5, 10); // {10, 10, 10, 10, 10}
// Initializer list constructor
std::vector<int> vec4 = {1, 2, 3, 4, 5};
// Copy constructor
std::vector<int> vec5(vec4);
2. Size and Capacity
- size(): Returns the number of elements in the vector.
- capacity(): Returns the size of the storage space currently allocated for the vector, expressed in terms of elements.
- empty(): Checks whether the vector is empty.
std::vector<int> vec = {1, 2, 3};
size_t sz = vec.size(); // 3
size_t cap = vec.capacity(); // >= 3
bool isEmpty = vec.empty(); // false
3. Element Access
- operator[]: Accesses element at a specific index without bounds checking.
- at(): Accesses element at a specific index with bounds checking.
- front(): Accesses the first element.
- back(): Accesses the last element.
- data(): Returns a pointer to the underlying array.
std::vector<int> vec = {10, 20, 30, 40, 50};
int first = vec[0]; // 10
int third = vec.at(2); // 30
int front = vec.front(); // 10
int back = vec.back(); // 50
int* ptr = vec.data(); // Pointer to the first element
4. Modifiers
- push_back(): Adds an element to the end of the vector.
- pop_back(): Removes the last element of the vector.
- insert(): Inserts elements at a specified position.
- erase(): Removes elements from a specified position or range.
- clear(): Removes all elements from the vector.
- resize(): Changes the number of elements stored.
- shrink_to_fit(): Reduces capacity to fit the size.
std::vector<int> vec = {1, 2, 3};
// push_back
vec.push_back(4); // {1, 2, 3, 4}
// pop_back
vec.pop_back(); // {1, 2, 3}
// insert
vec.insert(vec.begin() + 1, 10); // {1, 10, 2, 3}
// erase single element
vec.erase(vec.begin() + 2); // {1, 10, 3}
// erase range
vec.erase(vec.begin(), vec.begin() + 1); // {10, 3}
// clear
vec.clear(); // {}
// resize
vec.resize(5, 100); // {100, 100, 100, 100, 100}
// shrink_to_fit
vec.shrink_to_fit();
5. Iterators
Vectors support iterators to traverse and manipulate elements.
- begin(): Returns an iterator to the first element.
- end(): Returns an iterator to one past the last element.
- rbegin(): Returns a reverse iterator to the last element.
- rend(): Returns a reverse iterator to one before the first element.
std::vector<int> vec = {1, 2, 3, 4, 5};
// Forward iteration
for(auto it = vec.begin(); it != vec.end(); ++it) {
std::cout << *it << " ";
}
// Reverse iteration
for(auto it = vec.rbegin(); it != vec.rend(); ++it) {
std::cout << *it << " ";
}
6. Algorithms Support
Vectors work seamlessly with standard algorithms from the C++ Standard Library.
#include <algorithm>
std::vector<int> vec = {5, 3, 1, 4, 2};
// Sort the vector
std::sort(vec.begin(), vec.end()); // {1, 2, 3, 4, 5}
// Reverse the vector
std::reverse(vec.begin(), vec.end()); // {5, 4, 3, 2, 1}
// Find an element
auto it = std::find(vec.begin(), vec.end(), 3);
if(it != vec.end()) {
std::cout << "Found: " << *it << std::endl;
}
7. Capacity Management
- reserve(): Increases the vector's capacity to at least the specified number of elements.
- capacity(): Explained earlier.
std::vector<int> vec;
vec.reserve(100); // Reserve space for 100 elements
std::cout << "Capacity: " << vec.capacity() << std::endl;
Understanding and utilizing std::vector and its various methods can significantly enhance the efficiency and flexibility of your C++ programs, allowing for dynamic memory management and rich data manipulation capabilities.
Maps
C++ provides the std::map container, which is an associative container that stores elements formed by a combination of a key and a value. std::map automatically sorts its elements by key and allows fast retrieval of individual elements based on their keys.
Constructors
std::map offers multiple constructors to initialize maps in different ways.
#include <map>
#include <string>
// Default constructor
std::map<int, std::string> map1;
// Initializer list constructor
std::map<int, std::string> map2 = {
{1, "one"},
{2, "two"},
{3, "three"}
};
// Range constructor
std::vector<std::pair<int, std::string>> vec = { {4, "four"}, {5, "five"} };
std::map<int, std::string> map3(vec.begin(), vec.end());
// Copy constructor
std::map<int, std::string> map4(map2);
Size and Capacity
- size(): Returns the number of elements in the map.
- empty(): Checks whether the map is empty.
std::map<int, std::string> map = { {1, "one"}, {2, "two"}, {3, "three"} };
size_t sz = map.size(); // 3
bool isEmpty = map.empty(); // false
Element Access
- operator[]: Accesses or inserts elements with the given key.
- at(): Accesses elements with bounds checking.
- find(): Finds an element with a specific key.
- count(): Returns the number of elements with a specific key.
// Using operator[]
map[4] = "four"; // Inserts if key 4 does not exist
// Using at()
try {
std::string value = map.at(2); // "two"
} catch(const std::out_of_range& e) {
// Handle error
}
// Using find()
auto it = map.find(3);
if(it != map.end()) {
std::cout << "Found: " << it->second << std::endl; // "three"
}
// Using count()
if(map.count(5)) {
std::cout << "Key 5 exists." << std::endl;
} else {
std::cout << "Key 5 does not exist." << std::endl;
}
Inserting Elements
- insert(): Inserts elements into the map.
- emplace(): Constructs elements in-place.
// Using insert()
map.insert({1, "one"});
map.insert(std::pair<int, std::string>(2, "two"));
// Using emplace()
map.emplace(3, "three");
Deleting Elements
- erase(): Removes elements by key or iterator.
- clear(): Removes all elements from the map.
std::map<int, std::string> map = { {1, "one"}, {2, "two"}, {3, "three"} };
// Erase by key
map.erase(2);
// Erase by iterator
auto itErase = map.find(3);
if(itErase != map.end()) {
map.erase(itErase);
}
// Clear all elements
map.clear();
Iterating Through a Map
std::map<int, std::string> map = { {1, "one"}, {2, "two"}, {3, "three"} };
// Using iterator
for(auto it = map.begin(); it != map.end(); ++it) {
std::cout << it->first << ": " << it->second << std::endl;
}
// Using range-based for loop
for(const auto& pair : map) {
std::cout << pair.first << ": " << pair.second << std::endl;
}
Understanding and utilizing std::map and its various methods can greatly enhance your ability to manage key-value pairs efficiently in C++ applications.
Smart Pointers
Smart pointers in C++ are template classes provided by the Standard Library that facilitate automatic and exception-safe memory management. They help manage dynamically allocated objects by ensuring that resources are properly released when they are no longer needed, thus preventing memory leaks and other related issues. C++ offers several types of smart pointers, each tailored to specific use cases and ownership semantics.
Types of Smart Pointers
- std::unique_ptr
- std::shared_ptr
- std::weak_ptr
1. std::unique_ptr
std::unique_ptr is a smart pointer that owns and manages another object through a pointer and disposes of that object when the unique_ptr goes out of scope. It ensures exclusive ownership, meaning that there can be only one unique_ptr instance owning a particular object at any given time.
Key Characteristics:
- Exclusive Ownership: Only one std::unique_ptr can own the object at a time.
- No Copying: A unique_ptr cannot be copied, which prevents multiple owners. It can, however, be moved.
- Lightweight: Minimal overhead compared to raw pointers.
Usage Example:
#include <memory>
#include <iostream>
int main() {
// Creating a unique_ptr to an integer
std::unique_ptr<int> ptr1(new int(10));
std::cout << "Value: " << *ptr1 << std::endl; // Output: Value: 10
// Transferring ownership using std::move
std::unique_ptr<int> ptr2 = std::move(ptr1);
if (!ptr1) {
std::cout << "ptr1 is now null." << std::endl;
}
std::cout << "Value: " << *ptr2 << std::endl; // Output: Value: 10
// Automatic deletion when ptr2 goes out of scope
return 0;
}
Common Methods:
- get(): Returns the raw pointer.
- release(): Releases ownership of the managed object and returns the pointer.
- reset(): Deletes the currently managed object and takes ownership of a new one.
- operator* and operator->: Dereference operators to access the managed object.
2. std::shared_ptr
std::shared_ptr is a smart pointer that maintains shared ownership of an object through a pointer. Multiple shared_ptr instances can own the same object, and the object is destroyed only when the last shared_ptr owning it is destroyed or reset.
Key Characteristics:
- Shared Ownership: Multiple shared_ptr instances can own the same object.
- Reference Counting: Keeps track of how many shared_ptr instances own the object.
- Thread-Safe Reference Counting: The reference count itself can be safely updated from multiple threads.
Usage Example:
#include <memory>
#include <iostream>
int main() {
// Creating a shared_ptr to an integer
std::shared_ptr<int> ptr1 = std::make_shared<int>(20);
std::cout << "Value: " << *ptr1 << ", Count: " << ptr1.use_count() << std::endl; // Output: Value: 20, Count: 1
// Creating another shared_ptr sharing the same object
std::shared_ptr<int> ptr2 = ptr1;
std::cout << "Value: " << *ptr2 << ", Count: " << ptr1.use_count() << std::endl; // Output: Value: 20, Count: 2
// Resetting ptr1
ptr1.reset();
std::cout << "ptr1 reset. Count: " << ptr2.use_count() << std::endl; // Output: Count: 1
// Automatic deletion when ptr2 goes out of scope
return 0;
}
Common Methods:
- use_count(): Returns the number of shared_ptr instances sharing ownership.
- unique(): Checks if the shared_ptr is the only owner (deprecated since C++17; prefer use_count() == 1).
- reset(): Releases ownership of the managed object.
- swap(): Exchanges the managed object with another shared_ptr.
3. std::weak_ptr
std::weak_ptr is a smart pointer that holds a non-owning ("weak") reference to an object that is managed by std::shared_ptr. It is used to prevent circular references that can lead to memory leaks by allowing one part of the code to observe an object without affecting its lifetime.
Key Characteristics:
- Non-Owning: Does not contribute to the reference count.
- Avoids Circular References: Useful in scenarios like bidirectional relationships.
- Access Controlled: Must be converted to std::shared_ptr (via lock()) to access the managed object.
Usage Example:
#include <memory>
#include <iostream>
struct Node {
int value;
std::shared_ptr<Node> next;
std::weak_ptr<Node> prev; // Using weak_ptr to prevent circular reference
Node(int val) : value(val), next(nullptr), prev() {}
};
int main() {
auto node1 = std::make_shared<Node>(1);
auto node2 = std::make_shared<Node>(2);
node1->next = node2;
node2->prev = node1; // weak_ptr does not increase reference count
std::cout << "Node1 value: " << node1->value << std::endl;
std::cout << "Node2 value: " << node2->value << std::endl;
// Accessing the previous node
if(auto prev = node2->prev.lock()) {
std::cout << "Node2's previous node value: " << prev->value << std::endl;
} else {
std::cout << "Previous node no longer exists." << std::endl;
}
return 0;
}
Common Methods:
- lock(): Attempts to acquire a std::shared_ptr to the managed object.
- expired(): Checks if the managed object has been deleted.
- reset(): Releases the managed object reference.
Common Methods Across Smart Pointers
While each smart pointer type has its specific methods, there are several common methods that they share:
- get(): Returns the raw pointer managed by the smart pointer.
  std::unique_ptr<int> ptr = std::make_unique<int>(100);
  int* rawPtr = ptr.get();
  std::cout << "Raw pointer value: " << *rawPtr << std::endl; // Output: 100
- reset(): Releases the ownership of the managed object and optionally takes ownership of a new object.
  std::shared_ptr<int> ptr = std::make_shared<int>(200);
  ptr.reset(new int(300)); // Old object is deleted, ptr now owns the new object
  std::cout << "New value: " << *ptr << std::endl; // Output: 300
- swap(): Exchanges the managed objects of two smart pointers.
  std::unique_ptr<int> ptr1 = std::make_unique<int>(400);
  std::unique_ptr<int> ptr2 = std::make_unique<int>(500);
  ptr1.swap(ptr2);
  std::cout << "ptr1: " << *ptr1 << ", ptr2: " << *ptr2 << std::endl; // Output: ptr1: 500, ptr2: 400
- Dereference Operators (* and ->): Access the managed object.
  std::shared_ptr<std::string> ptr = std::make_shared<std::string>("Hello");
  std::cout << "String: " << *ptr << std::endl; // Output: Hello
  std::cout << "String length: " << ptr->length() << std::endl; // Output: 5
Best Practices
- Prefer std::make_unique and std::make_shared: These functions are exception-safe and more efficient.
  auto ptr = std::make_unique<MyClass>();
  auto sharedPtr = std::make_shared<MyClass>();
- Use std::unique_ptr When Ownership is Exclusive: It clearly signifies ownership semantics and incurs no reference-counting overhead.
  std::unique_ptr<Resource> resource = std::make_unique<Resource>();
- Use std::shared_ptr When Ownership is Shared: Useful when multiple parts of the program need to share access to the same resource.
  std::shared_ptr<Logger> logger1 = std::make_shared<Logger>();
  std::shared_ptr<Logger> logger2 = logger1;
- Avoid std::shared_ptr Unless Necessary: It introduces overhead due to reference counting. Use it only when shared ownership is required.
- Break Circular References with std::weak_ptr: When two objects share ownership via std::shared_ptr, use std::weak_ptr to prevent memory leaks.
  struct A { std::shared_ptr<B> b_ptr; };
  struct B { std::weak_ptr<A> a_ptr; }; // weak_ptr breaks the circular reference
Understanding and effectively utilizing smart pointers is crucial for modern C++ programming. They not only simplify memory management but also enhance the safety and performance of applications by preventing common issues related to dynamic memory allocation.
5. std::function and std::bind
std::function and std::bind are powerful utilities in the C++ Standard Library that facilitate higher-order programming by allowing functions to be treated as first-class objects. They enable the storage, modification, and invocation of functions in a flexible and generic manner, enhancing the capabilities of callback mechanisms, event handling, and functional programming paradigms in C++.
std::function
std::function is a versatile, type-erased function wrapper that can store any callable target—such as free functions, member functions, lambda expressions, or other function objects—provided they match a specific function signature. This flexibility makes it an essential tool for designing callback interfaces and managing dynamic function invocation.
Key Characteristics:
- Type-Erasure: Abstracts away the specific type of the callable, allowing different types of callable objects to be stored in the same std::function variable.
- Copyable and Assignable: std::function instances can be copied and assigned, enabling their use in standard containers and algorithms.
- Invoke Any Callable: Can represent free functions, member functions, lambda expressions, and function objects.
Basic Usage Example:
#include <functional>
#include <iostream>
// A free function
int add(int a, int b) {
return a + b;
}
int main() {
// Storing a free function in std::function
std::function<int(int, int)> func = add;
std::cout << "add(2, 3) = " << func(2, 3) << std::endl; // Output: 5
// Storing a lambda expression
std::function<int(int, int)> lambdaFunc = [](int a, int b) -> int {
return a * b;
};
std::cout << "lambdaFunc(2, 3) = " << lambdaFunc(2, 3) << std::endl; // Output: 6
// Storing a member function (requires binding)
struct Calculator {
int subtract(int a, int b) const {
return a - b;
}
};
Calculator calc;
std::function<int(int, int)> memberFunc = std::bind(&Calculator::subtract, &calc, std::placeholders::_1, std::placeholders::_2);
std::cout << "calc.subtract(5, 3) = " << memberFunc(5, 3) << std::endl; // Output: 2
return 0;
}
Common Methods:
- operator(): Invokes the stored callable.
- target(): Retrieves a pointer to the stored callable if it matches a specific type.
- operator bool: Checks whether a callable is stored; assigning nullptr clears the std::function, making it empty.
std::bind
std::bind is a utility that allows you to create a new function object by binding some or all of the arguments of an existing function to specific values. This is particularly useful for adapting functions to match desired interfaces or for creating callbacks with pre-specified arguments.
Key Characteristics:
- Argument Binding: Fixes certain arguments of a function, producing a new function object with fewer parameters.
- Placeholders: Uses placeholders like std::placeholders::_1 to indicate arguments that will be provided later.
- Supports Various Callables: Can bind free functions, member functions, and function objects.
Basic Usage Example:
#include <functional>
#include <iostream>
// A free function
int multiply(int a, int b) {
return a * b;
}
struct Calculator {
int divide(int a, int b) const {
if(b == 0) throw std::invalid_argument("Division by zero");
return a / b;
}
};
int main() {
// Binding the first argument of multiply to 5
auto timesFive = std::bind(multiply, 5, std::placeholders::_1);
std::cout << "multiply(5, 4) = " << timesFive(4) << std::endl; // Output: 20
// Binding a member function with the object instance
Calculator calc;
auto divideBy = std::bind(&Calculator::divide, &calc, std::placeholders::_1, 2);
std::cout << "calc.divide(10, 2) = " << divideBy(10) << std::endl; // Output: 5
return 0;
}
Common Use Cases:
- Creating Callbacks: Adapting functions to match callback interfaces that require a specific signature.
- Event Handling: Binding member functions of objects to event handlers with predefined arguments.
- Functional Programming: Enabling partial application and currying of functions for more functional-style code.
Advanced Usage Example:
#include <functional>
#include <iostream>
#include <vector>
class Logger {
public:
void log(const std::string& message, int level) const {
std::cout << "Level " << level << ": " << message << std::endl;
}
};
int main() {
Logger logger;
// Binding the logger object and log level to create a simplified log function
auto infoLog = std::bind(&Logger::log, &logger, std::placeholders::_1, 1);
auto errorLog = std::bind(&Logger::log, &logger, std::placeholders::_1, 3);
infoLog("This is an informational message."); // Output: Level 1: This is an informational message.
errorLog("This is an error message."); // Output: Level 3: This is an error message.
// Storing bind expressions in a std::vector of std::function
std::vector<std::function<void(const std::string&)>> logs;
logs.push_back(infoLog);
logs.push_back(errorLog);
for(auto& logFunc : logs) {
logFunc("Logging through stored function.");
}
// Output:
// Level 1: Logging through stored function.
// Level 3: Logging through stored function.
return 0;
}
Best Practices:
- Prefer Lambda Expressions Over std::bind: Lambdas often provide clearer and more readable syntax compared to std::bind.
  // Using std::bind
  auto timesFive = std::bind(multiply, 5, std::placeholders::_1);
  // Equivalent using a lambda
  auto timesFiveLambda = [](int a) -> int { return multiply(5, a); };
- Use std::function for Flexibility: When storing or passing callable objects that may vary in type, use std::function to accommodate different callables.
- Avoid Unnecessary Bindings: Excessive use of std::bind can lead to less readable code. Assess whether a lambda or a direct function call may be more appropriate.
By leveraging std::function and std::bind, developers can create more abstract, flexible, and reusable code components, facilitating sophisticated callback mechanisms and enhancing the expressive power of C++.
C++ in Competitive Programming
Competitive programming demands not only a deep understanding of algorithms and data structures but also the ability to implement them efficiently within strict time and memory constraints. C++ is a favored language in this arena due to its performance, rich Standard Template Library (STL), and powerful language features. Below are various methods and techniques in C++ that are extensively used in competitive programming:
1. Fast Input/Output
Efficient handling of input and output can significantly reduce execution time, especially with large datasets.
- Untie C++ Streams from C Streams:
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
  Disabling the synchronization between C and C++ standard streams and untying cin from cout can speed up I/O operations.
- Use of scanf and printf: For even faster I/O, some competitors prefer using C-style I/O functions.
2. Utilizing the Standard Template Library (STL)
The STL provides a suite of ready-to-use data structures and algorithms that can save time and reduce the likelihood of bugs.
- Vectors (std::vector): Dynamic arrays that allow for efficient random access and dynamic resizing.
  std::vector<int> vec = {1, 2, 3};
  vec.push_back(4);
- Pairs and Tuples (std::pair, std::tuple): Useful for storing multiple related values.
  std::pair<int, int> p = {1, 2};
  std::tuple<int, int, int> t = {1, 2, 3};
- Sets and Maps (std::set, std::map): Efficiently handle unique elements and key-value associations.
- Algorithms (std::sort, std::binary_search, etc.): Implement common algorithms with optimized performance.
3. Graph Representations and Algorithms
Graphs are a staple in competitive programming problems. Efficient representation and traversal are crucial.
- Adjacency List:
  int n; // Number of nodes
  std::vector<std::vector<int>> adj(n + 1);
  adj[u].push_back(v);
  adj[v].push_back(u); // For undirected graphs
- Depth-First Search (DFS) and Breadth-First Search (BFS): Fundamental traversal techniques.
- Dijkstra's and Floyd-Warshall Algorithms: For shortest path problems.
4. Dynamic Programming (DP)
DP is essential for solving optimization problems by breaking them down into simpler subproblems.
- Memoization and Tabulation:
  // Example of Fibonacci using memoization
  long long fib(int n, std::vector<long long> &dp) {
      if(n <= 1) return n;
      if(dp[n] != -1) return dp[n];
      return dp[n] = fib(n-1, dp) + fib(n-2, dp);
  }
- State Optimization: Reducing space complexity by optimizing states.
5. Greedy Algorithms
These algorithms make the locally optimal choice at each step with the hope of finding the global optimum.
- Interval Scheduling: Selecting the maximum number of non-overlapping intervals.
- Huffman Coding: For efficient encoding.
6. Bit Manipulation
Bitwise operations can optimize certain calculations and are useful in problems involving subsets or binary representations.
- Common Operations:
  - Setting a bit: x | (1 << pos)
  - Clearing a bit: x & ~(1 << pos)
  - Toggling a bit: x ^ (1 << pos)
- Bitmask DP: Using bitmasks to represent states in DP.
7. Number Theory
Many problems involve mathematical concepts such as primes, GCD, and modular arithmetic.
- Sieve of Eratosthenes: For finding all prime numbers up to a certain limit.
  std::vector<bool> is_prime(n+1, true);
  is_prime[0] = is_prime[1] = false;
  for(int i=2; i*i <= n; ++i){
      if(is_prime[i]){
          for(int j=i*i; j<=n; j+=i){
              is_prime[j] = false;
          }
      }
  }
- Modular Exponentiation: Efficiently computing large exponents under a modulus.
  long long power(long long a, long long b, long long mod){
      long long res = 1;
      a %= mod;
      while(b > 0){
          if(b & 1) res = res * a % mod;
          a = a * a % mod;
          b >>= 1;
      }
      return res;
  }
8. String Algorithms
Handling and processing strings efficiently is vital in many problems.
- KMP Algorithm: For pattern matching with linear time complexity.
- Trie Data Structure: Efficiently storing and searching a dynamic set of strings.
9. Data Structures
Choosing the right data structure can make or break your solution.
- Segment Trees and Binary Indexed Trees (Fenwick Trees): For range queries and updates.
- Disjoint Set Union (DSU): For efficiently handling union and find operations.
  struct DSU {
      std::vector<int> parent;
      DSU(int n) : parent(n+1) {
          for(int i=0;i<=n;i++) parent[i] = i;
      }
      int find_set(int x) {
          return parent[x] == x ? x : parent[x] = find_set(parent[x]);
      }
      void union_set(int x, int y) {
          parent[find_set(x)] = find_set(y);
      }
  };
- Heaps (std::priority_queue): Useful for efficiently retrieving the maximum or minimum element.
10. Advanced Techniques
- Meet in the Middle: Breaking problems into two halves to reduce time complexity.
- Bitmasking and Enumeration: Enumerating all subsets or combinations efficiently.
Best Practices
- Understand the Problem Thoroughly: Carefully read and comprehend the problem constraints and requirements before jumping into coding.
- Practice Code Implementation: Regularly practice implementing various algorithms and data structures to build speed and accuracy.
- Optimize and Test: Continuously look for optimizations and thoroughly test your code against different cases to ensure correctness.
- Stay Updated: Keep abreast of new algorithms and techniques emerging in the competitive programming community.
By mastering these methods and leveraging C++'s powerful features, competitive programmers can efficiently tackle a wide array of challenging problems and excel in contests.
JavaScript Programming
Overview
JavaScript is a high-level, interpreted programming language primarily used for web development. It enables interactive web pages and is an essential part of web applications alongside HTML and CSS.
Key Features:
- Event-driven, functional, and imperative programming styles
- Dynamic typing
- Prototype-based object-orientation
- First-class functions
- Runs in browsers and on servers (Node.js)
- Asynchronous programming with Promises and async/await
Basic Syntax
Variables
// var (function-scoped, avoid in modern code)
var x = 10;
// let (block-scoped, can be reassigned)
let y = 20;
y = 30; // OK
// const (block-scoped, cannot be reassigned)
const z = 40;
// z = 50; // ERROR!
// But const objects can be modified
const obj = { name: "Alice" };
obj.name = "Bob"; // OK
obj.age = 30; // OK
Data Types
// Primitives
let num = 42; // Number
let str = "Hello"; // String
let bool = true; // Boolean
let undef = undefined; // Undefined
let nul = null; // Null
let sym = Symbol("id"); // Symbol (ES6)
let bigInt = 123n; // BigInt (ES2020)
// Objects
let obj = { name: "Alice" };
let arr = [1, 2, 3];
let func = function() {};
// Type checking
typeof num; // "number"
typeof str; // "string"
typeof obj; // "object"
typeof arr; // "object" (arrays are objects)
Array.isArray(arr); // true
// Type conversion
String(42); // "42"
Number("42"); // 42
parseInt("42"); // 42
parseFloat("3.14"); // 3.14
Boolean(0); // false
Boolean(1); // true
Template Literals (ES6)
const name = "Alice";
const age = 30;
// Template literals
const message = `Hello, ${name}! You are ${age} years old.`;
// Multi-line strings
const multiline = `
This is a
multi-line
string
`;
// Tagged templates
function highlight(strings, ...values) {
return strings.reduce((acc, str, i) => {
return acc + str + (values[i] ? `<strong>${values[i]}</strong>` : '');
}, '');
}
const result = highlight`Name: ${name}, Age: ${age}`;
Arrays
// Creating arrays
const arr = [1, 2, 3, 4, 5];
const mixed = [1, "hello", true, null, { name: "Alice" }];
const empty = [];
// Accessing elements
const first = arr[0]; // 1
const last = arr[arr.length - 1]; // 5
// Common methods
arr.push(6); // Add to end: [1, 2, 3, 4, 5, 6]
arr.pop(); // Remove from end: 6
arr.unshift(0); // Add to start: [0, 1, 2, 3, 4, 5]
arr.shift(); // Remove from start: 0
arr.splice(2, 1); // Remove 1 element at index 2
arr.slice(1, 3); // Extract [2, 3]
// Iteration methods
arr.forEach((item, index) => {
console.log(index, item);
});
// Map (transform array)
const squares = arr.map(x => x * x);
// Filter (select elements)
const evens = arr.filter(x => x % 2 === 0);
// Reduce (aggregate)
const sum = arr.reduce((acc, val) => acc + val, 0);
// Find
const found = arr.find(x => x > 3); // First element > 3
const foundIndex = arr.findIndex(x => x > 3);
// Some and Every
const hasEven = arr.some(x => x % 2 === 0); // true if any even
const allEven = arr.every(x => x % 2 === 0); // true if all even
// Sorting
arr.sort((a, b) => a - b); // Ascending
arr.sort((a, b) => b - a); // Descending
// Spread operator
const arr1 = [1, 2, 3];
const arr2 = [4, 5, 6];
const combined = [...arr1, ...arr2]; // [1, 2, 3, 4, 5, 6]
// Destructuring
const [first, second, ...rest] = [1, 2, 3, 4, 5];
// first = 1, second = 2, rest = [3, 4, 5]
Objects
// Creating objects
const person = {
name: "Alice",
age: 30,
greet() {
return `Hello, I'm ${this.name}`;
}
};
// Accessing properties
person.name; // "Alice"
person["age"]; // 30
// Adding/modifying properties
person.email = "alice@example.com";
person.age = 31;
// Deleting properties
delete person.email;
// Object methods
Object.keys(person); // ["name", "age", "greet"]
Object.values(person); // ["Alice", 31, function]
Object.entries(person); // [["name", "Alice"], ["age", 31], ...]
// Spread operator
const person2 = { ...person, city: "NYC" };
// Destructuring
const { name, age } = person;
const { name: personName, age: personAge } = person; // Rename
// Computed property names
const key = "dynamicKey";
const obj = {
[key]: "value"
};
// Object shorthand (ES6)
const name = "Bob";
const age = 25;
const user = { name, age }; // Same as { name: name, age: age }
// Object.assign (merge objects)
const merged = Object.assign({}, person, { city: "NYC" });
// Freeze object (immutable)
Object.freeze(person);
Functions
Function Declaration
// Traditional function
function greet(name) {
return `Hello, ${name}!`;
}
// Function with default parameters
function greet(name = "World") {
return `Hello, ${name}!`;
}
// Rest parameters
function sum(...numbers) {
return numbers.reduce((acc, val) => acc + val, 0);
}
sum(1, 2, 3, 4, 5); // 15
Function Expressions
// Anonymous function
const greet = function(name) {
return `Hello, ${name}!`;
};
// Named function expression
const factorial = function fact(n) {
return n <= 1 ? 1 : n * fact(n - 1);
};
Arrow Functions (ES6)
// Basic arrow function
const greet = (name) => {
return `Hello, ${name}!`;
};
// Implicit return (single expression)
const greet = name => `Hello, ${name}!`;
// No parameters
const sayHello = () => "Hello!";
// Multiple parameters
const add = (a, b) => a + b;
// Arrow functions and 'this'
const person = {
name: "Alice",
greet: function() {
setTimeout(() => {
console.log(`Hello, ${this.name}`); // 'this' refers to person
}, 1000);
}
};
Higher-Order Functions
// Function that returns a function
function multiplier(factor) {
return function(number) {
return number * factor;
};
}
const double = multiplier(2);
console.log(double(5)); // 10
// Function that takes a function as argument
function applyOperation(arr, operation) {
return arr.map(operation);
}
const numbers = [1, 2, 3, 4, 5];
const squared = applyOperation(numbers, x => x * x);
Closures
function createCounter() {
let count = 0;
return {
increment() {
return ++count;
},
decrement() {
return --count;
},
getCount() {
return count;
}
};
}
const counter = createCounter();
counter.increment(); // 1
counter.increment(); // 2
counter.getCount(); // 2
Asynchronous JavaScript
Callbacks
// Traditional callback pattern
function fetchData(callback) {
setTimeout(() => {
callback("Data loaded");
}, 1000);
}
fetchData((data) => {
console.log(data);
});
// Callback hell (pyramid of doom)
getData1((data1) => {
getData2(data1, (data2) => {
getData3(data2, (data3) => {
console.log(data3);
});
});
});
Promises
// Creating a promise
const promise = new Promise((resolve, reject) => {
setTimeout(() => {
const success = true;
if (success) {
resolve("Data loaded");
} else {
reject("Error occurred");
}
}, 1000);
});
// Consuming a promise
promise
.then(data => {
console.log(data);
return "Next data";
})
.then(nextData => {
console.log(nextData);
})
.catch(error => {
console.error(error);
})
.finally(() => {
console.log("Cleanup");
});
// Promise chaining
fetch('https://api.example.com/data')
.then(response => response.json())
.then(data => console.log(data))
.catch(error => console.error(error));
// Promise.all (wait for all promises)
Promise.all([promise1, promise2, promise3])
.then(([result1, result2, result3]) => {
console.log(result1, result2, result3);
});
// Promise.race (first to complete)
Promise.race([promise1, promise2])
.then(result => console.log(result));
// Promise.allSettled (wait for all, regardless of result)
Promise.allSettled([promise1, promise2])
.then(results => console.log(results));
Async/Await (ES2017)
// Async function
async function fetchData() {
try {
const response = await fetch('https://api.example.com/data');
const data = await response.json();
console.log(data);
return data;
} catch (error) {
console.error('Error:', error);
}
}
// Sequential execution
async function sequential() {
const data1 = await fetchData1();
const data2 = await fetchData2(data1);
const data3 = await fetchData3(data2);
return data3;
}
// Parallel execution
async function parallel() {
const [data1, data2, data3] = await Promise.all([
fetchData1(),
fetchData2(),
fetchData3()
]);
return { data1, data2, data3 };
}
// Top-level await (ES2022)
const data = await fetchData();
Classes and OOP
ES6 Classes
// Basic class
class Person {
constructor(name, age) {
this.name = name;
this.age = age;
}
greet() {
return `Hello, I'm ${this.name}`;
}
// Static method
static species() {
return "Homo sapiens";
}
// Getter
get info() {
return `${this.name}, ${this.age}`;
}
// Setter
set info(value) {
const [name, age] = value.split(', ');
this.name = name;
this.age = parseInt(age);
}
}
const person = new Person("Alice", 30);
console.log(person.greet());
console.log(Person.species());
Inheritance
class Animal {
constructor(name) {
this.name = name;
}
speak() {
return `${this.name} makes a sound`;
}
}
class Dog extends Animal {
constructor(name, breed) {
super(name); // Call parent constructor
this.breed = breed;
}
speak() {
return `${this.name} barks`;
}
fetch() {
return `${this.name} is fetching`;
}
}
const dog = new Dog("Buddy", "Golden Retriever");
console.log(dog.speak()); // "Buddy barks"
console.log(dog.fetch()); // "Buddy is fetching"
Private Fields (ES2022)
class BankAccount {
#balance = 0; // Private field
deposit(amount) {
this.#balance += amount;
}
withdraw(amount) {
if (amount <= this.#balance) {
this.#balance -= amount;
return amount;
}
return 0;
}
getBalance() {
return this.#balance;
}
}
const account = new BankAccount();
account.deposit(100);
console.log(account.getBalance()); // 100
// console.log(account.#balance); // SyntaxError
Common Patterns
Module Pattern
const MyModule = (function() {
// Private variables
let privateVar = "I'm private";
// Private function
function privateFunction() {
return "Private function called";
}
// Public API
return {
publicVar: "I'm public",
publicFunction() {
return privateFunction();
},
getPrivateVar() {
return privateVar;
}
};
})();
console.log(MyModule.publicVar);
console.log(MyModule.publicFunction());
Revealing Module Pattern
const Calculator = (function() {
let result = 0;
function add(x) {
result += x;
return this;
}
function subtract(x) {
result -= x;
return this;
}
function getResult() {
return result;
}
function reset() {
result = 0;
return this;
}
return {
add,
subtract,
getResult,
reset
};
})();
Calculator.add(5).add(3).subtract(2);
console.log(Calculator.getResult()); // 6
Singleton Pattern
const Singleton = (function() {
let instance;
function createInstance() {
return {
name: "Singleton",
getData() {
return "Data from singleton";
}
};
}
return {
getInstance() {
if (!instance) {
instance = createInstance();
}
return instance;
}
};
})();
const instance1 = Singleton.getInstance();
const instance2 = Singleton.getInstance();
console.log(instance1 === instance2); // true
Factory Pattern
class Car {
constructor(options) {
this.doors = options.doors || 4;
this.state = options.state || "brand new";
this.color = options.color || "silver";
}
}
class Truck {
constructor(options) {
this.wheels = options.wheels || 6;
this.state = options.state || "used";
this.color = options.color || "blue";
}
}
class VehicleFactory {
createVehicle(type, options) {
switch(type) {
case 'car':
return new Car(options);
case 'truck':
return new Truck(options);
default:
throw new Error('Unknown vehicle type');
}
}
}
const factory = new VehicleFactory();
const car = factory.createVehicle('car', { color: 'red' });
const truck = factory.createVehicle('truck', { wheels: 8 });
Observer Pattern
class Subject {
constructor() {
this.observers = [];
}
subscribe(observer) {
this.observers.push(observer);
}
unsubscribe(observer) {
this.observers = this.observers.filter(obs => obs !== observer);
}
notify(data) {
this.observers.forEach(observer => observer.update(data));
}
}
class Observer {
constructor(name) {
this.name = name;
}
update(data) {
console.log(`${this.name} received: ${data}`);
}
}
const subject = new Subject();
const observer1 = new Observer('Observer 1');
const observer2 = new Observer('Observer 2');
subject.subscribe(observer1);
subject.subscribe(observer2);
subject.notify('Event occurred!');
DOM Manipulation
// Selecting elements
const element = document.getElementById('myId');
const elements = document.getElementsByClassName('myClass');
const element = document.querySelector('.myClass');
const elements = document.querySelectorAll('.myClass');
// Creating elements
const div = document.createElement('div');
div.textContent = 'Hello World';
div.className = 'my-class';
div.id = 'my-id';
// Appending elements
document.body.appendChild(div);
parentElement.insertBefore(newElement, referenceElement);
// Modifying content
element.textContent = 'New text';
element.innerHTML = '<strong>Bold text</strong>';
// Modifying attributes
element.setAttribute('data-id', '123');
element.getAttribute('data-id');
element.removeAttribute('data-id');
// Modifying styles
element.style.color = 'red';
element.style.fontSize = '20px';
// Adding/removing classes
element.classList.add('active');
element.classList.remove('inactive');
element.classList.toggle('visible');
element.classList.contains('active');
// Event listeners
element.addEventListener('click', (event) => {
console.log('Element clicked!', event);
});
element.addEventListener('click', handleClick);
element.removeEventListener('click', handleClick);
// Event delegation
document.addEventListener('click', (event) => {
if (event.target.matches('.my-button')) {
console.log('Button clicked!');
}
});
// Preventing default behavior
form.addEventListener('submit', (event) => {
event.preventDefault();
// Handle form submission
});
ES6+ Features
Destructuring
// Array destructuring
const [a, b, c] = [1, 2, 3];
const [first, , third] = [1, 2, 3];
const [head, ...tail] = [1, 2, 3, 4, 5];
// Object destructuring
const { name, age } = { name: 'Alice', age: 30 };
const { name: personName } = { name: 'Alice' }; // Rename
const { name, age = 25 } = { name: 'Alice' }; // Default value
// Nested destructuring
const { address: { city, country } } = person;
// Function parameter destructuring
function greet({ name, age }) {
return `Hello ${name}, you are ${age} years old`;
}
Spread and Rest Operators
// Spread in arrays
const arr1 = [1, 2, 3];
const arr2 = [...arr1, 4, 5, 6];
// Spread in objects
const obj1 = { a: 1, b: 2 };
const obj2 = { ...obj1, c: 3 };
// Rest in function parameters
function sum(...numbers) {
return numbers.reduce((acc, val) => acc + val, 0);
}
// Rest in destructuring
const [first, ...rest] = [1, 2, 3, 4, 5];
Optional Chaining (ES2020)
const user = {
name: 'Alice',
address: {
city: 'NYC'
}
};
// Without optional chaining
const city = user && user.address && user.address.city;
// With optional chaining
const city = user?.address?.city;
const fn = obj?.method?.(); // Call method if exists
Nullish Coalescing (ES2020)
// Returns right operand when left is null or undefined
const value = null ?? 'default'; // 'default'
const value = undefined ?? 'default'; // 'default'
const value = 0 ?? 'default'; // 0
const value = '' ?? 'default'; // ''
// Compare with || operator
const value = 0 || 'default'; // 'default'
const value = '' || 'default'; // 'default'
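The difference matters in practice whenever 0 or an empty string is a legitimate value. A small sketch (the settings object is made up for illustration):

```typescript
// Hypothetical user settings where 0 retries and an empty prefix are valid
const settings = { retries: 0, prefix: "" };

const retries = settings.retries ?? 3; // 0   -- ?? keeps the valid falsy value
const prefix = settings.prefix || ">"; // ">" -- || silently discards ""
```

Use `??` when only null/undefined should trigger the fallback, and `||` only when all falsy values should.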
Error Handling
// Try-catch
try {
throw new Error('Something went wrong');
} catch (error) {
console.error(error.message);
} finally {
console.log('Cleanup');
}
// Custom errors
class ValidationError extends Error {
constructor(message) {
super(message);
this.name = 'ValidationError';
}
}
try {
throw new ValidationError('Invalid input');
} catch (error) {
if (error instanceof ValidationError) {
console.error('Validation error:', error.message);
} else {
throw error; // Re-throw unknown errors
}
}
// Error handling with async/await
async function fetchData() {
try {
const response = await fetch('/api/data');
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
return await response.json();
} catch (error) {
console.error('Fetch error:', error);
throw error;
}
}
Common Array Methods
const numbers = [1, 2, 3, 4, 5];
// map - transform each element
const doubled = numbers.map(n => n * 2);
// filter - select elements
const evens = numbers.filter(n => n % 2 === 0);
// reduce - aggregate
const sum = numbers.reduce((acc, n) => acc + n, 0);
// find - first matching element
const found = numbers.find(n => n > 3);
// findIndex - index of first match
const index = numbers.findIndex(n => n > 3);
// some - at least one matches
const hasEven = numbers.some(n => n % 2 === 0);
// every - all match
const allPositive = numbers.every(n => n > 0);
// flat - flatten nested arrays
const nested = [1, [2, 3], [4, [5, 6]]];
const flat = nested.flat(2); // [1, 2, 3, 4, 5, 6]
// flatMap - map then flatten
const words = ['hello world', 'foo bar'];
const allWords = words.flatMap(w => w.split(' ')); // ['hello', 'world', 'foo', 'bar']
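These methods compose well into pipelines. A typical filter-map-reduce chain (the order data is made up for illustration):

```typescript
// Hypothetical order records
const orders = [
  { item: "book", qty: 2, price: 10 },
  { item: "pen", qty: 0, price: 2 },
  { item: "mug", qty: 1, price: 8 },
];

// Total value of non-empty orders: filter -> map -> reduce
const total = orders
  .filter(o => o.qty > 0)       // drop orders with qty 0
  .map(o => o.qty * o.price)    // line totals: [20, 8]
  .reduce((acc, n) => acc + n, 0); // 28
```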
Best Practices
- Use const by default, let when reassignment is needed
- Avoid var: it has function scope and hoisting issues
- Use arrow functions for callbacks and short functions
- Use template literals instead of string concatenation
- Use async/await instead of promise chains when possible
- Use destructuring for cleaner code
- Use the spread operator for copying arrays/objects
- Handle errors properly with try-catch
- Use strict mode: 'use strict';
- Use meaningful variable names
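A short sketch applying several of these practices at once (the formatUser function is hypothetical):

```typescript
'use strict';

// const + parameter destructuring with a default + template literal
function formatUser({ name, role = 'user' }: { name: string; role?: string }): string {
  return `${name} (${role})`;
}

const base = { name: 'Ada' };
const admin = { ...base, role: 'admin' }; // spread to copy and extend

formatUser(base);  // 'Ada (user)'
formatUser(admin); // 'Ada (admin)'
```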
Common Libraries/Frameworks
- React: UI library
- Vue.js: Progressive framework
- Angular: Full-featured framework
- Express.js: Web server framework (Node.js)
- Lodash: Utility library
- Axios: HTTP client
- Moment.js/Day.js: Date manipulation
- D3.js: Data visualization
TypeScript
TypeScript is a strongly-typed superset of JavaScript developed by Microsoft that compiles to plain JavaScript. It adds optional static typing, classes, interfaces, and other features to JavaScript, making it easier to build and maintain large-scale applications.
Table of Contents
- Why TypeScript?
- Basic Types
- Interfaces
- Type Aliases
- Union and Intersection Types
- Generics
- Classes
- Enums
- Type Assertions
- Type Guards
- Utility Types
- TypeScript with React
- TypeScript with Node.js
- Configuration (tsconfig.json)
- Advanced Types
- Best Practices
Why TypeScript?
Benefits:
- Type Safety: Catch errors at compile-time instead of runtime
- Better IDE Support: Enhanced autocomplete, navigation, and refactoring
- Self-Documenting: Types serve as inline documentation
- Scalability: Easier to maintain large codebases
- Modern JavaScript: Use latest JavaScript features with backward compatibility
- Refactoring Confidence: Safe refactoring with type checking
When to Use:
- Large-scale applications
- Team projects with multiple developers
- Projects requiring long-term maintenance
- When you need robust IDE support
- Enterprise applications
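A minimal illustration of the type-safety argument: a mistake that would surface only at runtime (or fail silently) in JavaScript is rejected by the TypeScript compiler before the code ever runs.

```typescript
// Annotated parameters let the compiler enforce the contract
function applyDiscount(price: number, percent: number): number {
  return price * (1 - percent / 100);
}

applyDiscount(200, 10);      // 180 -- OK
// applyDiscount(200, "10"); // Compile error: Argument of type 'string'
//                           // is not assignable to parameter of type 'number'
```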
Basic Types
Primitive Types
// Boolean
let isDone: boolean = false;
// Number
let decimal: number = 6;
let hex: number = 0xf00d;
let binary: number = 0b1010;
let octal: number = 0o744;
// String
let color: string = "blue";
let fullName: string = `Bob Bobbington`;
let sentence: string = `Hello, my name is ${fullName}.`;
// Array
let list: number[] = [1, 2, 3];
let list2: Array<number> = [1, 2, 3]; // Generic syntax
// Tuple - fixed-length array with known types
let x: [string, number];
x = ["hello", 10]; // OK
// x = [10, "hello"]; // Error
// Enum
enum Color {
Red,
Green,
Blue,
}
let c: Color = Color.Green;
// Any - opt-out of type checking
let notSure: any = 4;
notSure = "maybe a string instead";
notSure = false; // OK
// Unknown - type-safe alternative to any
let userInput: unknown;
userInput = 5;
userInput = "hello";
// let str: string = userInput; // Error
if (typeof userInput === "string") {
let str: string = userInput; // OK
}
// Void - absence of any type (typically for functions)
function warnUser(): void {
console.log("This is a warning message");
}
// Null and Undefined
let u: undefined = undefined;
let n: null = null;
// Never - represents values that never occur
function error(message: string): never {
throw new Error(message);
}
function infiniteLoop(): never {
while (true) {}
}
Interfaces
Interfaces define the structure of objects and enforce contracts in your code.
Basic Interface
interface User {
id: number;
name: string;
email: string;
age?: number; // Optional property
readonly createdAt: Date; // Read-only property
}
const user: User = {
id: 1,
name: "John Doe",
email: "john@example.com",
createdAt: new Date(),
};
// user.createdAt = new Date(); // Error: Cannot assign to 'createdAt'
Function Types
interface SearchFunc {
(source: string, subString: string): boolean;
}
const mySearch: SearchFunc = (source, subString) => {
return source.includes(subString);
};
Indexable Types
interface StringArray {
[index: number]: string;
}
let myArray: StringArray = ["Bob", "Fred"];
let myStr: string = myArray[0];
interface NumberDictionary {
[key: string]: number;
}
let dict: NumberDictionary = {
age: 25,
height: 180,
};
Extending Interfaces
interface Shape {
color: string;
}
interface Square extends Shape {
sideLength: number;
}
let square: Square = {
color: "blue",
sideLength: 10,
};
// Multiple inheritance
interface PenStroke {
penWidth: number;
}
interface FilledSquare extends Square, PenStroke {
filled: boolean;
}
Implementing Interfaces
interface ClockInterface {
currentTime: Date;
setTime(d: Date): void;
}
class Clock implements ClockInterface {
currentTime: Date = new Date();
setTime(d: Date): void {
this.currentTime = d;
}
}
Type Aliases
Type aliases create a new name for a type. Similar to interfaces but more flexible.
// Basic type alias
type ID = string | number;
type Point = {
x: number;
y: number;
};
// Union type
type Result = Success | Failure;
type Success = {
status: "success";
data: any;
};
type Failure = {
status: "error";
error: string;
};
// Function type
type GreetFunction = (name: string) => string;
const greet: GreetFunction = (name) => `Hello, ${name}!`;
// Intersection type
type Admin = {
privileges: string[];
};
type Employee = {
name: string;
startDate: Date;
};
type AdminEmployee = Admin & Employee;
const ae: AdminEmployee = {
privileges: ["create-server"],
name: "Max",
startDate: new Date(),
};
Interface vs Type Alias
// Interfaces can be merged (declaration merging)
interface Window {
title: string;
}
interface Window {
ts: number;
}
// Type aliases cannot be merged
// type Window = { title: string };
// type Window = { ts: number }; // Error: Duplicate identifier
// Type aliases can represent unions and tuples
type StringOrNumber = string | number;
type Tuple = [string, number];
// Both can be extended
interface Shape {
color: string;
}
// Interface extending interface
interface Circle extends Shape {
radius: number;
}
// Type extending type
type ColoredShape = Shape & { filled: boolean };
// Type extending interface
type ColoredCircle = Circle & { filled: boolean };
// Interface extending type
type Size = { width: number; height: number };
interface Rectangle extends Size {
color: string;
}
Union and Intersection Types
Union Types
A union type can be one of several types.
function printId(id: number | string) {
console.log("Your ID is: " + id);
}
printId(101); // OK
printId("202"); // OK
// printId({ myID: 22342 }); // Error
// Discriminated Unions (Tagged Unions)
type Shape =
| { kind: "circle"; radius: number }
| { kind: "square"; sideLength: number }
| { kind: "rectangle"; width: number; height: number };
function getArea(shape: Shape): number {
switch (shape.kind) {
case "circle":
return Math.PI * shape.radius ** 2;
case "square":
return shape.sideLength ** 2;
case "rectangle":
return shape.width * shape.height;
}
}
Intersection Types
An intersection type combines multiple types into one.
interface Colorful {
color: string;
}
interface Circle {
radius: number;
}
type ColorfulCircle = Colorful & Circle;
const cc: ColorfulCircle = {
color: "red",
radius: 42,
};
Generics
Generics allow you to create reusable components that work with multiple types.
Generic Functions
function identity<T>(arg: T): T {
return arg;
}
let output1 = identity<string>("myString");
let output2 = identity<number>(123);
let output3 = identity("myString"); // Type inference
// Generic with constraints
interface Lengthwise {
length: number;
}
function loggingIdentity<T extends Lengthwise>(arg: T): T {
console.log(arg.length);
return arg;
}
loggingIdentity({ length: 10, value: 3 }); // OK
loggingIdentity([1, 2, 3]); // OK
// loggingIdentity(3); // Error: number doesn't have length
Generic Interfaces
interface GenericIdentityFn<T> {
(arg: T): T;
}
let myIdentity: GenericIdentityFn<number> = identity;
// Generic container
interface Container<T> {
value: T;
getValue(): T;
setValue(value: T): void;
}
class Box<T> implements Container<T> {
constructor(public value: T) {}
getValue(): T {
return this.value;
}
setValue(value: T): void {
this.value = value;
}
}
const numberBox = new Box<number>(42);
const stringBox = new Box<string>("hello");
Generic Classes
class GenericNumber<T> {
zeroValue!: T; // definite assignment assertion (needed under strictPropertyInitialization)
add!: (x: T, y: T) => T;
}
let myGenericNumber = new GenericNumber<number>();
myGenericNumber.zeroValue = 0;
myGenericNumber.add = (x, y) => x + y;
let stringNumeric = new GenericNumber<string>();
stringNumeric.zeroValue = "";
stringNumeric.add = (x, y) => x + y;
Generic Constraints
function getProperty<T, K extends keyof T>(obj: T, key: K): T[K] {
return obj[key];
}
let x = { a: 1, b: 2, c: 3, d: 4 };
getProperty(x, "a"); // OK
// getProperty(x, "m"); // Error: "m" is not in 'a' | 'b' | 'c' | 'd'
Advanced Generic Patterns
// Generic type with default
type APIResponse<T = any> = {
data: T;
status: number;
message: string;
};
// Multiple type parameters
function merge<T, U>(obj1: T, obj2: U): T & U {
return { ...obj1, ...obj2 };
}
const merged = merge({ name: "John" }, { age: 30 });
// merged: { name: string } & { age: number }
// Conditional types with generics
type NonNullable<T> = T extends null | undefined ? never : T;
type A = NonNullable<string | null>; // string
type B = NonNullable<number | undefined>; // number
Classes
Basic Class
class Greeter {
greeting: string;
constructor(message: string) {
this.greeting = message;
}
greet(): string {
return `Hello, ${this.greeting}`;
}
}
let greeter = new Greeter("world");
Inheritance
class Animal {
name: string;
constructor(name: string) {
this.name = name;
}
move(distanceInMeters: number = 0): void {
console.log(`${this.name} moved ${distanceInMeters}m.`);
}
}
class Dog extends Animal {
bark(): void {
console.log("Woof! Woof!");
}
}
const dog = new Dog("Buddy");
dog.bark();
dog.move(10);
Access Modifiers
class Person {
public name: string; // Public by default
private age: number; // Only accessible within the class
protected email: string; // Accessible in class and subclasses
readonly id: number; // Cannot be modified after initialization
constructor(name: string, age: number, email: string, id: number) {
this.name = name;
this.age = age;
this.email = email;
this.id = id;
}
getAge(): number {
return this.age;
}
}
class Employee extends Person {
constructor(name: string, age: number, email: string, id: number) {
super(name, age, email, id);
}
getEmail(): string {
return this.email; // OK: protected is accessible in subclass
}
}
const person = new Person("John", 30, "john@example.com", 1);
console.log(person.name); // OK
// console.log(person.age); // Error: private
// console.log(person.email); // Error: protected
Getters and Setters
class Employee {
private _fullName: string = "";
get fullName(): string {
return this._fullName;
}
set fullName(newName: string) {
if (newName && newName.length > 0) {
this._fullName = newName;
} else {
throw new Error("Invalid name");
}
}
}
let employee = new Employee();
employee.fullName = "Bob Smith";
console.log(employee.fullName);
Abstract Classes
abstract class Department {
constructor(public name: string) {}
printName(): void {
console.log("Department name: " + this.name);
}
abstract printMeeting(): void; // Must be implemented in derived class
}
class AccountingDepartment extends Department {
constructor() {
super("Accounting and Auditing");
}
printMeeting(): void {
console.log("The Accounting Department meets each Monday at 10am.");
}
generateReports(): void {
console.log("Generating accounting reports...");
}
}
let department: Department = new AccountingDepartment();
department.printName();
department.printMeeting();
// department.generateReports(); // Error: method doesn't exist on Department
Static Members
class Grid {
static origin = { x: 0, y: 0 };
calculateDistanceFromOrigin(point: { x: number; y: number }): number {
let xDist = point.x - Grid.origin.x;
let yDist = point.y - Grid.origin.y;
return Math.sqrt(xDist * xDist + yDist * yDist);
}
}
console.log(Grid.origin);
let grid = new Grid();
Enums
Enums allow defining a set of named constants.
Numeric Enums
enum Direction {
Up = 1,
Down,
Left,
Right,
}
// Starts from 1 and auto-increments
console.log(Direction.Up); // 1
console.log(Direction.Down); // 2
enum Response {
No = 0,
Yes = 1,
}
function respond(recipient: string, message: Response): void {
// ...
}
respond("Princess Caroline", Response.Yes);
String Enums
enum Direction {
Up = "UP",
Down = "DOWN",
Left = "LEFT",
Right = "RIGHT",
}
console.log(Direction.Up); // "UP"
Const Enums
const enum Enum {
A = 1,
B = A * 2,
}
// Compiled code is inlined (better performance)
let value = Enum.B; // Becomes: let value = 2;
Enum as Type
enum Status {
Active,
Inactive,
Pending,
}
interface User {
name: string;
status: Status;
}
const user: User = {
name: "John",
status: Status.Active,
};
Type Assertions
Type assertions tell the compiler to treat a value as a specific type.
// Angle-bracket syntax
let someValue: any = "this is a string";
let strLength: number = (<string>someValue).length;
// As syntax (preferred in JSX/TSX)
let someValue2: any = "this is a string";
let strLength2: number = (someValue2 as string).length;
// Non-null assertion operator
function liveDangerously(x?: number | null) {
// TypeScript will trust that x is not null/undefined
console.log(x!.toFixed());
}
// Const assertions
let x = "hello" as const; // Type: "hello" (not string)
let y = [10, 20] as const; // Type: readonly [10, 20]
let z = {
name: "John",
age: 30,
} as const; // All properties are readonly
Type Guards
Type guards allow you to narrow down the type of a variable within a conditional block.
typeof Guards
function padLeft(value: string, padding: string | number) {
if (typeof padding === "number") {
return Array(padding + 1).join(" ") + value;
}
if (typeof padding === "string") {
return padding + value;
}
throw new Error(`Expected string or number, got '${typeof padding}'.`);
}
instanceof Guards
class Bird {
fly() {
console.log("Flying");
}
}
class Fish {
swim() {
console.log("Swimming");
}
}
function move(animal: Bird | Fish) {
if (animal instanceof Bird) {
animal.fly();
} else {
animal.swim();
}
}
in Operator
type Fish = { swim: () => void };
type Bird = { fly: () => void };
function move(animal: Fish | Bird) {
if ("swim" in animal) {
animal.swim();
} else {
animal.fly();
}
}
Custom Type Guards
interface Cat {
meow(): void;
}
interface Dog {
bark(): void;
}
function isCat(pet: Cat | Dog): pet is Cat {
return (pet as Cat).meow !== undefined;
}
function makeSound(pet: Cat | Dog) {
if (isCat(pet)) {
pet.meow();
} else {
pet.bark();
}
}
Utility Types
TypeScript provides several utility types for common type transformations.
Partial
Makes all properties optional.
interface User {
id: number;
name: string;
email: string;
}
function updateUser(user: User, updates: Partial<User>): User {
return { ...user, ...updates };
}
const user: User = { id: 1, name: "John", email: "john@example.com" };
const updated = updateUser(user, { name: "Jane" });
Required
Makes all properties required.
interface Props {
a?: number;
b?: string;
}
const obj: Required<Props> = { a: 5, b: "text" };
Readonly
Makes all properties readonly.
interface User {
name: string;
age: number;
}
const user: Readonly<User> = {
name: "John",
age: 30,
};
// user.name = "Jane"; // Error: Cannot assign to 'name'
Pick<T, K>
Creates a type by picking specific properties from another type.
interface User {
id: number;
name: string;
email: string;
age: number;
}
type UserPreview = Pick<User, "id" | "name">;
// { id: number; name: string; }
const preview: UserPreview = { id: 1, name: "John" };
Omit<T, K>
Creates a type by omitting specific properties.
interface User {
id: number;
name: string;
email: string;
password: string;
}
type UserPublic = Omit<User, "password">;
// { id: number; name: string; email: string; }
Record<K, T>
Creates an object type with keys of type K and values of type T.
type PageInfo = {
title: string;
url: string;
};
type Page = "home" | "about" | "contact";
const pages: Record<Page, PageInfo> = {
home: { title: "Home", url: "/" },
about: { title: "About", url: "/about" },
contact: { title: "Contact", url: "/contact" },
};
Exclude<T, U> and Extract<T, U>
Exclude removes from T the union members assignable to U; Extract keeps only the members assignable to U.
type T0 = Exclude<"a" | "b" | "c", "a">; // "b" | "c"
type T1 = Exclude<string | number | (() => void), Function>; // string | number
type T2 = Extract<"a" | "b" | "c", "a" | "f">; // "a"
type T3 = Extract<string | number | (() => void), Function>; // () => void
ReturnType
Extracts the return type of a function type.
function getUser() {
return { id: 1, name: "John", email: "john@example.com" };
}
type User = ReturnType<typeof getUser>;
// { id: number; name: string; email: string; }
Parameters
Extracts parameter types of a function type as a tuple.
function createUser(name: string, age: number, email: string) {
return { name, age, email };
}
type CreateUserParams = Parameters<typeof createUser>;
// [name: string, age: number, email: string]
TypeScript with React
Functional Components
import React from "react";
// Props interface
interface ButtonProps {
label: string;
onClick: () => void;
disabled?: boolean;
variant?: "primary" | "secondary";
}
// Functional component
const Button: React.FC<ButtonProps> = ({
label,
onClick,
disabled = false,
variant = "primary",
}) => {
return (
<button onClick={onClick} disabled={disabled} className={variant}>
{label}
</button>
);
};
// Alternative (recommended in modern React)
function Button2({ label, onClick, disabled }: ButtonProps) {
return <button onClick={onClick} disabled={disabled}>{label}</button>;
}
}
export default Button;
Component with Children
interface CardProps {
title: string;
children: React.ReactNode;
}
const Card: React.FC<CardProps> = ({ title, children }) => {
return (
<div className="card">
<h2>{title}</h2>
<div className="card-body">{children}</div>
</div>
);
};
useState Hook
import { useState } from "react";
interface User {
id: number;
name: string;
}
function UserComponent() {
// Type inference
const [count, setCount] = useState(0);
// Explicit type
const [user, setUser] = useState<User | null>(null);
// With initial state
const [users, setUsers] = useState<User[]>([]);
return (
<div>
<p>Count: {count}</p>
<button onClick={() => setCount(count + 1)}>Increment</button>
</div>
);
}
useEffect Hook
import { useEffect, useState } from "react";
function DataFetcher() {
const [data, setData] = useState<any>(null);
useEffect(() => {
async function fetchData() {
const response = await fetch("/api/data");
const result = await response.json();
setData(result);
}
fetchData();
}, []); // Empty dependency array
return <div>{data ? JSON.stringify(data) : "Loading..."}</div>;
}
useRef Hook
import { useRef, useEffect } from "react";
function TextInput() {
const inputRef = useRef<HTMLInputElement>(null);
useEffect(() => {
inputRef.current?.focus();
}, []);
return <input ref={inputRef} type="text" />;
}
useContext Hook
import { createContext, useContext, useState } from "react";
interface AuthContextType {
user: User | null;
login: (user: User) => void;
logout: () => void;
}
const AuthContext = createContext<AuthContextType | undefined>(undefined);
export const AuthProvider: React.FC<{ children: React.ReactNode }> = ({
children,
}) => {
const [user, setUser] = useState<User | null>(null);
const login = (user: User) => setUser(user);
const logout = () => setUser(null);
return (
<AuthContext.Provider value={{ user, login, logout }}>
{children}
</AuthContext.Provider>
);
};
export const useAuth = () => {
const context = useContext(AuthContext);
if (context === undefined) {
throw new Error("useAuth must be used within AuthProvider");
}
return context;
};
Custom Hooks
import { useState, useEffect } from "react";
function useFetch<T>(url: string) {
const [data, setData] = useState<T | null>(null);
const [loading, setLoading] = useState(true);
const [error, setError] = useState<Error | null>(null);
useEffect(() => {
async function fetchData() {
try {
const response = await fetch(url);
const result = await response.json();
setData(result);
} catch (err) {
setError(err as Error);
} finally {
setLoading(false);
}
}
fetchData();
}, [url]);
return { data, loading, error };
}
// Usage
interface User {
id: number;
name: string;
}
function UserList() {
const { data: users, loading, error } = useFetch<User[]>("/api/users");
if (loading) return <div>Loading...</div>;
if (error) return <div>Error: {error.message}</div>;
return (
<ul>
{users?.map((user) => (
<li key={user.id}>{user.name}</li>
))}
</ul>
);
}
Event Handlers
import React from "react";
function Form() {
const handleSubmit = (e: React.FormEvent<HTMLFormElement>) => {
e.preventDefault();
// Handle form submission
};
const handleChange = (e: React.ChangeEvent<HTMLInputElement>) => {
console.log(e.target.value);
};
const handleClick = (e: React.MouseEvent<HTMLButtonElement>) => {
console.log("Button clicked");
};
return (
<form onSubmit={handleSubmit}>
<input type="text" onChange={handleChange} />
<button onClick={handleClick}>Submit</button>
</form>
);
}
TypeScript with Node.js
Basic Express Server
import express, { Request, Response, NextFunction } from "express";
const app = express();
const PORT = 3000;
app.use(express.json());
// Basic route
app.get("/", (req: Request, res: Response) => {
res.json({ message: "Hello World" });
});
// Route with params
app.get("/users/:id", (req: Request, res: Response) => {
const userId = req.params.id;
res.json({ id: userId });
});
// POST route with body
interface CreateUserBody {
name: string;
email: string;
}
app.post("/users", (req: Request<{}, {}, CreateUserBody>, res: Response) => {
const { name, email } = req.body;
res.json({ id: 1, name, email });
});
// Middleware
const logger = (req: Request, res: Response, next: NextFunction) => {
console.log(`${req.method} ${req.path}`);
next();
};
app.use(logger);
// Error handling middleware
app.use((err: Error, req: Request, res: Response, next: NextFunction) => {
console.error(err.stack);
res.status(500).json({ error: err.message });
});
app.listen(PORT, () => {
console.log(`Server running on port ${PORT}`);
});
Custom Request Types
import { Request, Response } from "express";
interface UserRequest extends Request {
user?: {
id: number;
email: string;
};
}
app.get("/profile", (req: UserRequest, res: Response) => {
if (!req.user) {
return res.status(401).json({ error: "Unauthorized" });
}
res.json(req.user);
});
Async/Await with Express
import { Request, Response, NextFunction } from "express";
// Wrapper for async route handlers
const asyncHandler =
(fn: Function) => (req: Request, res: Response, next: NextFunction) => {
Promise.resolve(fn(req, res, next)).catch(next);
};
app.get(
"/users",
asyncHandler(async (req: Request, res: Response) => {
const users = await getUsersFromDB();
res.json(users);
})
);
File System Operations
import * as fs from "fs/promises";
import * as path from "path";
async function readConfig(): Promise<any> {
try {
const configPath = path.join(__dirname, "config.json");
const data = await fs.readFile(configPath, "utf-8");
return JSON.parse(data);
} catch (error) {
console.error("Error reading config:", error);
throw error;
}
}
async function writeLog(message: string): Promise<void> {
const logPath = path.join(__dirname, "app.log");
const timestamp = new Date().toISOString();
const logEntry = `[${timestamp}] ${message}\n`;
await fs.appendFile(logPath, logEntry);
}
Configuration (tsconfig.json)
Basic Configuration
{
"compilerOptions": {
"target": "ES2020",
"module": "commonjs",
"lib": ["ES2020"],
"outDir": "./dist",
"rootDir": "./src",
"strict": true,
"esModuleInterop": true,
"skipLibCheck": true,
"forceConsistentCasingInFileNames": true,
"resolveJsonModule": true,
"moduleResolution": "node",
"declaration": true,
"declarationMap": true,
"sourceMap": true
},
"include": ["src/**/*"],
"exclude": ["node_modules", "dist"]
}
Important Compiler Options
Strict Type Checking
{
"compilerOptions": {
"strict": true, // Enable all strict type checking
"noImplicitAny": true, // Error on expressions with implied 'any'
"strictNullChecks": true, // Enable strict null checks
"strictFunctionTypes": true, // Enable strict checking of function types
"strictBindCallApply": true, // Enable strict bind/call/apply methods
"strictPropertyInitialization": true, // Ensure properties are initialized
"noImplicitThis": true, // Error on 'this' expressions with implied 'any'
"alwaysStrict": true // Parse in strict mode and emit "use strict"
}
}
Module Resolution
{
"compilerOptions": {
"module": "commonjs", // Module code generation
"moduleResolution": "node", // Module resolution strategy
"baseUrl": "./", // Base directory for module resolution
"paths": { // Path mappings
"@/*": ["src/*"],
"@components/*": ["src/components/*"]
},
"esModuleInterop": true, // Emit helpers for importing CommonJS modules
"allowSyntheticDefaultImports": true // Allow default imports from modules
}
}
React Configuration
{
"compilerOptions": {
"jsx": "react-jsx", // JSX code generation (React 17+)
// "jsx": "react", // For React 16 and earlier
"lib": ["DOM", "DOM.Iterable", "ES2020"]
}
}
Node.js Configuration
{
"compilerOptions": {
"target": "ES2020",
"module": "commonjs",
"lib": ["ES2020"],
"types": ["node"],
"esModuleInterop": true
}
}
Project References
For monorepos or multi-package projects:
{
"compilerOptions": {
"composite": true,
"declaration": true,
"declarationMap": true
},
"references": [
{ "path": "../common" },
{ "path": "../utils" }
]
}
Advanced Types
Conditional Types
type IsString<T> = T extends string ? true : false;
type A = IsString<string>; // true
type B = IsString<number>; // false
// Distributive conditional types
type ToArray<T> = T extends any ? T[] : never;
type StrOrNumArray = ToArray<string | number>; // string[] | number[]
// Infer keyword
type ReturnType<T> = T extends (...args: any[]) => infer R ? R : never;
type Func = () => number;
type Result = ReturnType<Func>; // number
Mapped Types
type Readonly<T> = {
readonly [P in keyof T]: T[P];
};
type Optional<T> = {
[P in keyof T]?: T[P];
};
type Nullable<T> = {
[P in keyof T]: T[P] | null;
};
interface Person {
name: string;
age: number;
}
type ReadonlyPerson = Readonly<Person>;
// { readonly name: string; readonly age: number; }
type OptionalPerson = Optional<Person>;
// { name?: string; age?: number; }
Template Literal Types
type EventName = "click" | "scroll" | "mousemove";
type Handler = `on${Capitalize<EventName>}`;
// "onClick" | "onScroll" | "onMousemove"
type PropEventSource<Type> = {
on<Key extends string & keyof Type>(
eventName: `${Key}Changed`,
callback: (newValue: Type[Key]) => void
): void;
};
declare function makeWatchedObject<Type>(
obj: Type
): Type & PropEventSource<Type>;
const person = makeWatchedObject({
firstName: "John",
age: 26,
});
person.on("firstNameChanged", (newName) => {
console.log(`New name: ${newName}`);
});
Index Signatures
interface StringArray {
[index: number]: string;
}
interface StringByString {
[key: string]: string | number;
length: number; // OK: number is assignable to string | number
}
// Generic index signature
interface Dictionary<T> {
[key: string]: T;
}
const userScores: Dictionary<number> = {
john: 100,
jane: 95,
};
Discriminated Unions (Tagged Unions)
interface Square {
kind: "square";
size: number;
}
interface Rectangle {
kind: "rectangle";
width: number;
height: number;
}
interface Circle {
kind: "circle";
radius: number;
}
type Shape = Square | Rectangle | Circle;
function area(s: Shape): number {
switch (s.kind) {
case "square":
return s.size * s.size;
case "rectangle":
return s.width * s.height;
case "circle":
return Math.PI * s.radius ** 2;
}
}
Best Practices
1. Use Strict Mode
Always enable strict: true in tsconfig.json for maximum type safety.
{
"compilerOptions": {
"strict": true
}
}
2. Avoid any Type
Use unknown instead of any when the type is truly unknown.
// Bad
function process(data: any) {
return data.value;
}
// Good
function process(data: unknown) {
if (typeof data === "object" && data !== null && "value" in data) {
return (data as { value: any }).value;
}
throw new Error("Invalid data");
}
3. Use Type Inference
Let TypeScript infer types when possible.
// Bad
const numbers: number[] = [1, 2, 3];
const result: number = numbers.reduce((acc: number, n: number) => acc + n, 0);
// Good
const numbers = [1, 2, 3];
const result = numbers.reduce((acc, n) => acc + n, 0);
4. Use Readonly When Appropriate
// Readonly arrays
const numbers: readonly number[] = [1, 2, 3];
// numbers.push(4); // Error
// Readonly objects
interface Config {
readonly apiUrl: string;
readonly timeout: number;
}
// Readonly function parameters
function printList(list: readonly string[]) {
// list.push("new"); // Error
console.log(list.join(", "));
}
5. Prefer Interfaces for Objects, Types for Unions/Intersections
// Good: Use interface for object shapes
interface User {
id: number;
name: string;
}
// Good: Use type for unions
type Status = "pending" | "approved" | "rejected";
// Good: Use type for intersections
type AdminUser = User & { role: "admin" };
6. Use Discriminated Unions for Complex State
type RequestState =
| { status: "idle" }
| { status: "loading" }
| { status: "success"; data: any }
| { status: "error"; error: string };
function handleRequest(state: RequestState) {
switch (state.status) {
case "idle":
return "Not started";
case "loading":
return "Loading...";
case "success":
return state.data;
case "error":
return state.error;
}
}
7. Use Const Assertions
// Without const assertion
const colors = ["red", "green", "blue"];
// Type: string[]
// With const assertion
const colors = ["red", "green", "blue"] as const;
// Type: readonly ["red", "green", "blue"]
const config = {
apiUrl: "https://api.example.com",
timeout: 5000,
} as const;
// All properties are readonly
8. Use Type Guards
function isString(value: unknown): value is string {
return typeof value === "string";
}
function processValue(value: string | number) {
if (isString(value)) {
console.log(value.toUpperCase());
} else {
console.log(value.toFixed(2));
}
}
9. Use Generics for Reusable Code
// Generic function
function firstOrNull<T>(arr: T[]): T | null {
return arr.length > 0 ? arr[0] : null;
}
// Generic constraints
function getProperty<T, K extends keyof T>(obj: T, key: K): T[K] {
return obj[key];
}
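A quick usage sketch for the generic constraint above (the user object is illustrative): because the key parameter is constrained to keyof T, misspelled property names are compile errors and the return type is inferred per key.

```typescript
// Same generic as above, repeated so this snippet stands alone.
function getProperty<T, K extends keyof T>(obj: T, key: K): T[K] {
  return obj[key];
}

const user = { id: 1, name: "Alice" };

const userName = getProperty(user, "name"); // inferred as string
const userId = getProperty(user, "id");     // inferred as number
// getProperty(user, "email"); // compile error: "email" is not a key of user
```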
10. Use Utility Types
// Instead of manually creating partial types
interface User {
id: number;
name: string;
email: string;
}
// Good: Use Partial utility type
function updateUser(id: number, updates: Partial<User>) {
// Implementation
}
// Good: Use Pick for selecting specific properties
type UserPreview = Pick<User, "id" | "name">;
// Good: Use Omit to exclude properties
type UserWithoutId = Omit<User, "id">;
11. Avoid Type Assertions When Possible
// Bad
const data = JSON.parse(jsonString) as User;
// Good: Validate at runtime
function isUser(data: any): data is User {
return (
typeof data === "object" &&
typeof data.id === "number" &&
typeof data.name === "string"
);
}
const data = JSON.parse(jsonString);
if (isUser(data)) {
// TypeScript knows data is User here
}
12. Use Enum Alternatives
// Instead of enum
enum Status {
Pending,
Approved,
Rejected,
}
// Consider union types
type Status = "pending" | "approved" | "rejected";
// Or const objects with 'as const'
const Status = {
Pending: "pending",
Approved: "approved",
Rejected: "rejected",
} as const;
type StatusValue = (typeof Status)[keyof typeof Status];
13. Document Complex Types
/**
* Represents a user in the system
* @property id - Unique identifier
* @property name - Full name of the user
* @property email - Contact email address
*/
interface User {
id: number;
name: string;
email: string;
}
/**
* Fetches user data from the API
* @param userId - The ID of the user to fetch
* @returns Promise resolving to user data
* @throws {Error} When user is not found
*/
async function fetchUser(userId: number): Promise<User> {
// Implementation
}
14. Use Namespace for Organization (Sparingly)
namespace Validation {
export interface StringValidator {
isValid(s: string): boolean;
}
export class EmailValidator implements StringValidator {
isValid(s: string): boolean {
return /^[\w-\.]+@([\w-]+\.)+[\w-]{2,4}$/.test(s);
}
}
}
const validator = new Validation.EmailValidator();
15. Leverage TSConfig Paths
{
"compilerOptions": {
"baseUrl": ".",
"paths": {
"@components/*": ["src/components/*"],
"@utils/*": ["src/utils/*"],
"@models/*": ["src/models/*"]
}
}
}
// Instead of
import { Button } from "../../../components/Button";
// Use
import { Button } from "@components/Button";
Common Patterns
Factory Pattern
interface Product {
operation(): string;
}
class ConcreteProductA implements Product {
operation(): string {
return "Product A";
}
}
class ConcreteProductB implements Product {
operation(): string {
return "Product B";
}
}
class ProductFactory {
createProduct(type: "A" | "B"): Product {
switch (type) {
case "A":
return new ConcreteProductA();
case "B":
return new ConcreteProductB();
}
}
}
Builder Pattern
class QueryBuilder {
private query: string = "";
select(...fields: string[]): this {
this.query += `SELECT ${fields.join(", ")} `;
return this;
}
from(table: string): this {
this.query += `FROM ${table} `;
return this;
}
where(condition: string): this {
this.query += `WHERE ${condition} `;
return this;
}
build(): string {
return this.query.trim();
}
}
const query = new QueryBuilder()
.select("id", "name")
.from("users")
.where("age > 18")
.build();
Singleton Pattern
class Database {
private static instance: Database;
private connection: any;
private constructor() {
// Private constructor prevents instantiation
this.connection = this.connect();
}
private connect() {
// Connection logic
return {};
}
static getInstance(): Database {
if (!Database.instance) {
Database.instance = new Database();
}
return Database.instance;
}
query(sql: string) {
// Query logic
}
}
const db1 = Database.getInstance();
const db2 = Database.getInstance();
console.log(db1 === db2); // true
Resources
- Official Documentation: https://www.typescriptlang.org/docs/
- TypeScript Playground: https://www.typescriptlang.org/play
- Definitely Typed: https://github.com/DefinitelyTyped/DefinitelyTyped
- TypeScript Deep Dive: https://basarat.gitbook.io/typescript/
- React TypeScript Cheatsheet: https://react-typescript-cheatsheet.netlify.app/
TypeScript significantly improves the development experience by catching errors early, providing better tooling support, and making code more maintainable. The initial learning curve is worth the long-term benefits, especially for large-scale applications and team projects.
Bash Programming
Overview
Bash (Bourne Again SHell) is a Unix shell and command language used for automating tasks, system administration, and scripting. It's the default shell on most Linux distributions and macOS.
Key Features:
- Command execution and scripting
- Text processing and file manipulation
- Process control and job management
- Environment variable management
- Piping and redirection
- Pattern matching and globbing
Basic Syntax
Variables
# Variable assignment (no spaces around =)
name="Alice"
age=30
readonly PI=3.14159 # Read-only variable
# Accessing variables
echo "Hello, $name"
echo "Hello, ${name}!" # Recommended for clarity
# Command substitution
current_date=$(date)
current_dir=`pwd` # Old style, avoid
# Default values
echo "${var:-default}" # Use default if var is unset
echo "${var:=default}" # Set var to default if unset
echo "${var:+alternate}" # Use alternate if var is set
echo "${var:?error message}" # Error if var is unset
# String length
name="Alice"
echo "${#name}" # 5
# Substring
echo "${name:0:3}" # Ali
Data Types
# Strings
str="Hello World"
str='Single quotes - literal'
str="Double quotes - $variable expansion"
# Arrays
fruits=("apple" "banana" "cherry")
echo "${fruits[0]}" # apple
echo "${fruits[@]}" # All elements
echo "${#fruits[@]}" # Array length
fruits+=("date") # Append
# Associative arrays (Bash 4+)
declare -A person
person[name]="Alice"
person[age]=30
echo "${person[name]}"
# Integers
declare -i num=42
num=$num+10 # Arithmetic
Control Flow
If Statements
# Basic if
if [ "$age" -gt 18 ]; then
echo "Adult"
fi
# If-elif-else
if [ "$age" -lt 13 ]; then
echo "Child"
elif [ "$age" -lt 20 ]; then
echo "Teenager"
else
echo "Adult"
fi
# String comparison
if [ "$name" = "Alice" ]; then
echo "Hello Alice"
fi
if [ "$name" != "Bob" ]; then
echo "Not Bob"
fi
# File tests
if [ -f "file.txt" ]; then
echo "File exists"
fi
if [ -d "directory" ]; then
echo "Directory exists"
fi
if [ -r "file.txt" ]; then
echo "File is readable"
fi
# Logical operators
if [ "$age" -gt 18 ] && [ "$age" -lt 65 ]; then
echo "Working age"
fi
if [ "$age" -lt 18 ] || [ "$age" -gt 65 ]; then
echo "Not working age"
fi
# Modern test syntax [[ ]]
if [[ "$name" == "Alice" ]]; then
echo "Hello Alice"
fi
if [[ "$name" =~ ^A ]]; then # Regex matching
echo "Name starts with A"
fi
Comparison Operators
# Numeric comparison
[ "$a" -eq "$b" ] # Equal
[ "$a" -ne "$b" ] # Not equal
[ "$a" -gt "$b" ] # Greater than
[ "$a" -ge "$b" ] # Greater than or equal
[ "$a" -lt "$b" ] # Less than
[ "$a" -le "$b" ] # Less than or equal
# String comparison
[ "$a" = "$b" ] # Equal
[ "$a" != "$b" ] # Not equal
[ -z "$a" ] # String is empty
[ -n "$a" ] # String is not empty
# File tests
[ -e file ] # Exists
[ -f file ] # Regular file
[ -d file ] # Directory
[ -r file ] # Readable
[ -w file ] # Writable
[ -x file ] # Executable
[ -s file ] # Not empty
[ file1 -nt file2 ] # file1 newer than file2
[ file1 -ot file2 ] # file1 older than file2
Loops
# For loop
for i in 1 2 3 4 5; do
echo "$i"
done
# C-style for loop
for ((i=0; i<5; i++)); do
echo "$i"
done
# For loop with range
for i in {1..10}; do
echo "$i"
done
# For loop with step
for i in {0..10..2}; do
echo "$i" # 0, 2, 4, 6, 8, 10
done
# Iterate over array
fruits=("apple" "banana" "cherry")
for fruit in "${fruits[@]}"; do
echo "$fruit"
done
# Iterate over files
for file in *.txt; do
echo "Processing $file"
done
# While loop
count=0
while [ $count -lt 5 ]; do
echo "$count"
((count++))
done
# Read file line by line
while IFS= read -r line; do
echo "$line"
done < file.txt
# Until loop
count=0
until [ $count -ge 5 ]; do
echo "$count"
((count++))
done
# Break and continue
for i in {1..10}; do
if [ $i -eq 5 ]; then
continue # Skip 5
fi
if [ $i -eq 8 ]; then
break # Stop at 8
fi
echo "$i"
done
Case Statements
case "$1" in
start)
echo "Starting service..."
;;
stop)
echo "Stopping service..."
;;
restart)
echo "Restarting service..."
;;
*)
echo "Usage: $0 {start|stop|restart}"
exit 1
;;
esac
# Pattern matching in case
case "$filename" in
*.txt)
echo "Text file"
;;
*.jpg|*.png)
echo "Image file"
;;
*)
echo "Unknown file type"
;;
esac
Functions
# Basic function
greet() {
echo "Hello, $1!"
}
greet "Alice" # Hello, Alice!
# Function with return value
add() {
local result=$(($1 + $2))
echo "$result"
}
sum=$(add 5 3)
echo "Sum: $sum"
# Function with return code
check_file() {
if [ -f "$1" ]; then
return 0 # Success
else
return 1 # Failure
fi
}
if check_file "file.txt"; then
echo "File exists"
else
echo "File not found"
fi
# Local variables
my_function() {
local local_var="I'm local"
global_var="I'm global"
}
# Function with multiple return values
get_stats() {
local min=1
local max=100
local avg=50
echo "$min $max $avg"
}
read -r min max avg <<< "$(get_stats)"
echo "Min: $min, Max: $max, Avg: $avg"
String Manipulation
# Length
str="Hello World"
echo "${#str}" # 11
# Substring
echo "${str:0:5}" # Hello
echo "${str:6}" # World
echo "${str: -5}" # World (note space before -)
# Replace
echo "${str/World/Universe}" # Hello Universe (first occurrence)
echo "${str//o/O}" # HellO WOrld (all occurrences)
# Remove prefix/suffix
filename="example.tar.gz"
echo "${filename#*.}" # tar.gz (remove shortest prefix)
echo "${filename##*.}" # gz (remove longest prefix)
echo "${filename%.*}" # example.tar (remove shortest suffix)
echo "${filename%%.*}" # example (remove longest suffix)
# Upper/Lower case
str="Hello World"
echo "${str^^}" # HELLO WORLD
echo "${str,,}" # hello world
echo "${str^}" # Hello world (first char upper)
# Trim whitespace
str=" hello "
str="${str#"${str%%[![:space:]]*}"}" # Trim left
str="${str%"${str##*[![:space:]]}"}" # Trim right
Input/Output
Reading Input
# Read from user
read -p "Enter your name: " name
echo "Hello, $name!"
# Read with timeout
if read -t 5 -p "Enter value (5s timeout): " value; then
echo "You entered: $value"
else
echo "Timeout!"
fi
# Read password (hidden)
read -s -p "Enter password: " password
echo
# Read multiple values
read -p "Enter name and age: " name age
# Read into array
IFS=',' read -ra array <<< "apple,banana,cherry"
Output
# Echo
echo "Hello World"
echo -n "No newline"
echo -e "Line1\nLine2" # Enable escape sequences
# Printf (more control)
printf "Name: %s, Age: %d\n" "Alice" 30
printf "%.2f\n" 3.14159 # 3.14
# Here document
cat << EOF
This is a
multi-line
message
EOF
# Here string
grep "pattern" <<< "string to search"
Redirection
# Output redirection
echo "Hello" > file.txt # Overwrite
echo "World" >> file.txt # Append
# Input redirection
while IFS= read -r line; do
echo "$line"
done < file.txt
# Error redirection
command 2> error.log # Redirect stderr
command > output.txt 2>&1 # Redirect both stdout and stderr
command &> all_output.txt # Same as above (Bash 4+)
# Discard output
command > /dev/null 2>&1
# Pipe
cat file.txt | grep "pattern" | sort | uniq
# Tee (write to file and stdout)
echo "Hello" | tee file.txt
# Process substitution
diff <(ls dir1) <(ls dir2)
File Operations
# Create file
touch file.txt
echo "content" > file.txt
# Copy
cp source.txt dest.txt
cp -r source_dir/ dest_dir/
# Move/Rename
mv old.txt new.txt
mv file.txt directory/
# Delete
rm file.txt
rm -r directory/
rm -f file.txt # Force delete
# Create directory
mkdir directory
mkdir -p path/to/nested/directory
# Read file
cat file.txt
head -n 10 file.txt # First 10 lines
tail -n 10 file.txt # Last 10 lines
tail -f file.txt # Follow file (live updates)
# File permissions
chmod 755 script.sh # rwxr-xr-x
chmod +x script.sh # Add execute permission
chmod u+x script.sh # User execute
chmod go-w file.txt # Remove write for group and others
# File ownership
chown user:group file.txt
# Find files
find . -name "*.txt"
find . -type f -name "*.log"
find . -mtime -7 # Modified in last 7 days
find . -size +10M # Larger than 10MB
Process Management
# Run in background
command &
# List jobs
jobs
# Bring to foreground
fg %1
# Send to background
bg %1
# Kill process
kill PID
kill -9 PID # Force kill
killall process_name
# Process info
ps aux
ps aux | grep process_name
top
htop
# Exit status
command
echo $? # 0 = success, non-zero = failure
# Conditional execution
command1 && command2 # command2 runs if command1 succeeds
command1 || command2 # command2 runs if command1 fails
command1 ; command2 # command2 runs regardless
# Wait for process
command &
PID=$!
wait $PID
Common Patterns
Error Handling
# Exit on error
set -e # Exit if any command fails
set -u # Exit if undefined variable is used
set -o pipefail # Exit if any command in pipe fails
# Combined
set -euo pipefail
# Error function
error_exit() {
echo "ERROR: $1" >&2
exit 1
}
[ -f "file.txt" ] || error_exit "File not found"
# Trap errors
trap 'echo "Error on line $LINENO"' ERR
# Cleanup on exit
cleanup() {
rm -f /tmp/tempfile
}
trap cleanup EXIT
Argument Parsing
# Positional arguments
echo "Script: $0"
echo "First arg: $1"
echo "Second arg: $2"
echo "All args: $@"
echo "Number of args: $#"
# Shift arguments
while [ $# -gt 0 ]; do
echo "$1"
shift
done
# Parse options
while getopts "a:b:c" opt; do
case $opt in
a)
echo "Option -a: $OPTARG"
;;
b)
echo "Option -b: $OPTARG"
;;
c)
echo "Option -c"
;;
\?)
echo "Invalid option: -$OPTARG"
exit 1
;;
esac
done
# Long options
while [ $# -gt 0 ]; do
case "$1" in
--help)
echo "Usage: $0 [options]"
exit 0
;;
--file=*)
FILE="${1#*=}"
;;
--verbose)
VERBOSE=1
;;
*)
echo "Unknown option: $1"
exit 1
;;
esac
shift
done
Logging
# Simple logging
log() {
echo "[$(date '+%Y-%m-%d %H:%M:%S')] $1"
}
log "Script started"
# Log levels
LOG_LEVEL=${LOG_LEVEL:-INFO}
log_debug() {
[ "$LOG_LEVEL" = "DEBUG" ] && echo "[DEBUG] $1"
}
log_info() {
echo "[INFO] $1"
}
log_error() {
echo "[ERROR] $1" >&2
}
# Log to file
exec > >(tee -a script.log)
exec 2>&1
Configuration Files
# Source configuration
CONFIG_FILE="config.sh"
if [ -f "$CONFIG_FILE" ]; then
source "$CONFIG_FILE"
fi
# config.sh
# DB_HOST="localhost"
# DB_PORT=5432
# DB_NAME="mydb"
# Read key-value pairs
while IFS='=' read -r key value; do
case "$key" in
DB_HOST) DB_HOST="$value" ;;
DB_PORT) DB_PORT="$value" ;;
DB_NAME) DB_NAME="$value" ;;
esac
done < config.txt
Text Processing
# grep (search)
grep "pattern" file.txt
grep -i "pattern" file.txt # Case insensitive
grep -r "pattern" directory/ # Recursive
grep -v "pattern" file.txt # Invert match
grep -n "pattern" file.txt # Show line numbers
grep -c "pattern" file.txt # Count matches
grep -E "regex" file.txt # Extended regex
# sed (stream editor)
sed 's/old/new/' file.txt # Replace first occurrence
sed 's/old/new/g' file.txt # Replace all
sed -i 's/old/new/g' file.txt # In-place edit
sed -n '10,20p' file.txt # Print lines 10-20
sed '/pattern/d' file.txt # Delete matching lines
# awk (text processing)
awk '{print $1}' file.txt # Print first column
awk '{print $1, $3}' file.txt # Print columns 1 and 3
awk -F: '{print $1}' /etc/passwd # Custom delimiter
awk '$3 > 100' file.txt # Filter rows
awk '{sum += $1} END {print sum}' file.txt # Sum column
# cut (extract columns)
cut -d: -f1 /etc/passwd # Field 1, delimiter :
cut -c1-10 file.txt # Characters 1-10
# sort
sort file.txt
sort -r file.txt # Reverse
sort -n file.txt # Numeric sort
sort -k2 file.txt # Sort by column 2
sort -u file.txt # Unique
# uniq (unique lines)
sort file.txt | uniq # Remove duplicates
sort file.txt | uniq -c # Count occurrences
sort file.txt | uniq -d # Only duplicates
# wc (word count)
wc -l file.txt # Line count
wc -w file.txt # Word count
wc -c file.txt # Byte count
# tr (translate characters)
echo "hello" | tr 'a-z' 'A-Z' # HELLO
echo "hello123" | tr -d '0-9' # hello
Common Utilities
# Date and time
date # Current date/time
date '+%Y-%m-%d' # 2024-01-15
date '+%Y-%m-%d %H:%M:%S' # 2024-01-15 14:30:00
date -d "yesterday" # Yesterday's date
date -d "+7 days" # Date 7 days from now
# Arithmetic
echo $((5 + 3)) # 8
echo $((10 / 3)) # 3 (integer division)
echo "scale=2; 10 / 3" | bc # 3.33 (bc for floating point)
# Random numbers
echo $RANDOM # Random number 0-32767
echo $((RANDOM % 100)) # Random 0-99
# Sleep
sleep 5 # Sleep 5 seconds
sleep 0.5 # Sleep 0.5 seconds
# Command existence check
if command -v git &> /dev/null; then
echo "Git is installed"
fi
# Array operations
arr=(1 2 3 4 5)
echo "${arr[@]}" # All elements
echo "${#arr[@]}" # Length
echo "${arr[@]:1:3}" # Slice [1:4]
arr+=(6) # Append
Best Practices
- Always quote variables: "$var", not $var
- Use set -euo pipefail for safer scripts
- Check command existence before using
- Validate input and arguments
- Use functions for code reuse
- Add comments and documentation
- Use meaningful variable names
- Handle errors explicitly
- Use [[ ]] instead of [ ] for conditions
- Avoid parsing ls output - use globbing or find
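The quoting rule can be demonstrated with a small sketch; the path value and the count_words helper are illustrative. An unquoted expansion undergoes word splitting, so a value containing spaces silently becomes multiple words.

```shell
#!/usr/bin/env bash
# Word splitting demo: the path contains a space, so the unquoted
# expansion becomes two words while the quoted one stays intact.
path="My Documents/report.txt"

count_words() { echo "$#"; }

unquoted=$(count_words $path)    # splits on the space: 2 words
quoted=$(count_words "$path")    # stays one word: 1 word

echo "unquoted: $unquoted, quoted: $quoted"
```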
Script Template
#!/usr/bin/env bash
# Script: script_name.sh
# Description: What this script does
# Author: Your Name
# Date: 2024-01-15
set -euo pipefail # Exit on error, undefined var, pipe failure
# Constants
readonly SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
readonly SCRIPT_NAME="$(basename "$0")"
# Variables
VERBOSE=0
DRY_RUN=0
# Functions
usage() {
cat << EOF
Usage: $SCRIPT_NAME [OPTIONS]
Description of what the script does.
OPTIONS:
-h, --help Show this help message
-v, --verbose Verbose output
-n, --dry-run Dry run mode
EOF
}
log_info() {
echo "[INFO] $*"
}
log_error() {
echo "[ERROR] $*" >&2
}
cleanup() {
log_info "Cleaning up..."
# Cleanup code here
}
main() {
log_info "Starting script..."
# Main script logic here
log_info "Script completed successfully"
}
# Parse arguments
while [[ $# -gt 0 ]]; do
case $1 in
-h|--help)
usage
exit 0
;;
-v|--verbose)
VERBOSE=1
shift
;;
-n|--dry-run)
DRY_RUN=1
shift
;;
*)
log_error "Unknown option: $1"
usage
exit 1
;;
esac
done
# Trap cleanup on exit
trap cleanup EXIT
# Run main function
main "$@"
Common Use Cases
Backup Script
#!/bin/bash
BACKUP_DIR="/backup"
SOURCE_DIR="/data"
DATE=$(date +%Y%m%d_%H%M%S)
BACKUP_FILE="backup_${DATE}.tar.gz"
tar -czf "${BACKUP_DIR}/${BACKUP_FILE}" "${SOURCE_DIR}"
echo "Backup created: ${BACKUP_FILE}"
# Keep only last 7 backups
cd "${BACKUP_DIR}"
ls -t backup_*.tar.gz | tail -n +8 | xargs -r rm
System Monitoring
#!/bin/bash
CPU_THRESHOLD=80
DISK_THRESHOLD=90
# Check CPU usage
CPU_USAGE=$(top -bn1 | grep "Cpu(s)" | awk '{print $2}' | cut -d'%' -f1)
if (( $(echo "$CPU_USAGE > $CPU_THRESHOLD" | bc -l) )); then
echo "WARNING: CPU usage is ${CPU_USAGE}%"
fi
# Check disk usage
df -H | grep -vE '^Filesystem|tmpfs|cdrom' | awk '{print $5 " " $1}' | while read -r output; do
usage=$(echo $output | awk '{print $1}' | sed 's/%//g')
partition=$(echo $output | awk '{print $2}')
if [ $usage -ge $DISK_THRESHOLD ]; then
echo "WARNING: Disk usage on $partition is ${usage}%"
fi
done
Useful Resources
- ShellCheck: Linter for shell scripts
- Bash Manual: man bash
- Bash Guide: https://mywiki.wooledge.org/BashGuide
- Explainshell: Explain shell commands (explainshell.com)
Java Programming
Overview
Java is a high-level, class-based, object-oriented programming language designed to have minimal implementation dependencies. It follows the "write once, run anywhere" (WORA) principle.
Key Features:
- Platform independent (runs on JVM)
- Object-oriented programming
- Automatic memory management (Garbage Collection)
- Strong type system
- Rich standard library
- Multi-threading support
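The multi-threading support listed above is otherwise not demonstrated in these notes, so here is a minimal sketch using Thread and AtomicInteger (class and method names are illustrative): two threads increment a shared counter, and join() waits for both before reading the result.

```java
import java.util.concurrent.atomic.AtomicInteger;

public class ThreadDemo {
    static int countWithTwoThreads() throws InterruptedException {
        AtomicInteger counter = new AtomicInteger(0);
        Runnable task = () -> {
            for (int i = 0; i < 1000; i++) {
                counter.incrementAndGet(); // atomic, safe across threads
            }
        };
        Thread t1 = new Thread(task);
        Thread t2 = new Thread(task);
        t1.start();
        t2.start();
        t1.join(); // wait for both threads to finish
        t2.join();
        return counter.get(); // 2000: no increments are lost
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println("Count: " + countWithTwoThreads());
    }
}
```

A plain int counter here could lose updates; AtomicInteger (or synchronization) is what makes the result deterministic.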
Basic Syntax
Variables and Data Types
// Primitive types
byte b = 127; // 8-bit
short s = 32767; // 16-bit
int i = 2147483647; // 32-bit
long l = 9223372036854775807L; // 64-bit
float f = 3.14f; // 32-bit floating point
double d = 3.14159; // 64-bit floating point
boolean bool = true; // true or false
char c = 'A'; // 16-bit Unicode
// Reference types
String str = "Hello, World!";
Integer num = 42; // Wrapper class
// Type conversion
int x = (int) 3.14; // Explicit casting
double y = 10; // Implicit casting
// Constants
final double PI = 3.14159;
final int MAX_SIZE = 100;
String Operations
// String creation
String s1 = "Hello";
String s2 = new String("World");
// String methods
int length = s1.length();
char ch = s1.charAt(0);
String sub = s1.substring(0, 3);
String upper = s1.toUpperCase();
String lower = s1.toLowerCase();
boolean startsWith = s1.startsWith("He");
boolean contains = s1.contains("ll");
// String comparison
boolean equals = s1.equals(s2);
boolean equalsIgnoreCase = s1.equalsIgnoreCase(s2);
int compare = s1.compareTo(s2);
// String concatenation
String full = s1 + " " + s2;
String joined = String.join(", ", "a", "b", "c");
// String formatting
String formatted = String.format("Name: %s, Age: %d", "Alice", 30);
// StringBuilder (mutable)
StringBuilder sb = new StringBuilder();
sb.append("Hello");
sb.append(" World");
String result = sb.toString();
Arrays and Collections
Arrays
// Array declaration
int[] numbers = new int[5];
int[] nums = {1, 2, 3, 4, 5};
String[] names = {"Alice", "Bob", "Charlie"};
// Accessing elements
int first = nums[0];
nums[2] = 10;
// Array length
int length = nums.length;
// Multi-dimensional arrays
int[][] matrix = new int[3][3];
int[][] grid = {{1, 2}, {3, 4}, {5, 6}};
// Arrays utility class
import java.util.Arrays;
Arrays.sort(nums); // Sort array
int index = Arrays.binarySearch(nums, 5); // Binary search
int[] copy = Arrays.copyOf(nums, nums.length); // Copy
boolean equal = Arrays.equals(nums, copy); // Compare
String str = Arrays.toString(nums); // Convert to string
ArrayList
import java.util.ArrayList;
// Creating ArrayList
ArrayList<String> list = new ArrayList<>();
ArrayList<Integer> numbers = new ArrayList<>(Arrays.asList(1, 2, 3));
// Adding elements
list.add("Apple");
list.add(0, "Banana"); // Add at index
list.addAll(Arrays.asList("Cherry", "Date"));
// Accessing elements
String first = list.get(0);
list.set(1, "Blueberry");
// Removing elements
list.remove(0);
list.remove("Apple");
list.clear();
// Operations
int size = list.size();
boolean empty = list.isEmpty();
boolean contains = list.contains("Apple");
int index = list.indexOf("Apple");
// Iteration
for (String item : list) {
System.out.println(item);
}
list.forEach(item -> System.out.println(item));
HashMap
import java.util.HashMap;
import java.util.Map;
// Creating HashMap
HashMap<String, Integer> map = new HashMap<>();
// Adding elements
map.put("Alice", 25);
map.put("Bob", 30);
map.putIfAbsent("Charlie", 35);
// Accessing elements
int age = map.get("Alice");
int defaultAge = map.getOrDefault("David", 0);
// Removing elements
map.remove("Bob");
// Operations
int size = map.size();
boolean empty = map.isEmpty();
boolean hasKey = map.containsKey("Alice");
boolean hasValue = map.containsValue(25);
// Iteration
for (Map.Entry<String, Integer> entry : map.entrySet()) {
System.out.println(entry.getKey() + ": " + entry.getValue());
}
map.forEach((key, value) ->
System.out.println(key + ": " + value));
Control Flow
If-Else
int age = 18;
if (age < 13) {
System.out.println("Child");
} else if (age < 20) {
System.out.println("Teenager");
} else {
System.out.println("Adult");
}
// Ternary operator
String status = (age >= 18) ? "Adult" : "Minor";
Switch
// Traditional switch
int day = 3;
switch (day) {
case 1:
System.out.println("Monday");
break;
case 2:
System.out.println("Tuesday");
break;
default:
System.out.println("Other day");
}
// Switch expression (Java 14+)
String dayName = switch (day) {
case 1 -> "Monday";
case 2 -> "Tuesday";
case 3 -> "Wednesday";
default -> "Other day";
};
Loops
// For loop
for (int i = 0; i < 5; i++) {
System.out.println(i);
}
// Enhanced for loop
int[] numbers = {1, 2, 3, 4, 5};
for (int num : numbers) {
System.out.println(num);
}
// While loop
int count = 0;
while (count < 5) {
System.out.println(count);
count++;
}
// Do-while loop
int i = 0;
do {
System.out.println(i);
i++;
} while (i < 5);
// Break and continue
for (int j = 0; j < 10; j++) {
if (j == 5) continue; // Skip 5
if (j == 8) break; // Stop at 8
System.out.println(j);
}
Object-Oriented Programming
Classes and Objects
public class Person {
// Fields (instance variables)
private String name;
private int age;
// Static field (class variable)
private static int count = 0;
// Constructor
public Person(String name, int age) {
this.name = name;
this.age = age;
count++;
}
// Default constructor
public Person() {
this("Unknown", 0);
}
// Getters and setters
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public int getAge() {
return age;
}
public void setAge(int age) {
if (age >= 0) {
this.age = age;
}
}
// Instance method
public void greet() {
System.out.println("Hello, I'm " + name);
}
// Static method
public static int getCount() {
return count;
}
// toString method
@Override
public String toString() {
return "Person{name='" + name + "', age=" + age + "}";
}
}
// Usage
Person person = new Person("Alice", 30);
person.greet();
System.out.println(person.toString());
Inheritance
// Base class
public class Animal {
protected String name;
public Animal(String name) {
this.name = name;
}
public void speak() {
System.out.println(name + " makes a sound");
}
}
// Derived class
public class Dog extends Animal {
private String breed;
public Dog(String name, String breed) {
super(name); // Call parent constructor
this.breed = breed;
}
@Override
public void speak() {
System.out.println(name + " barks");
}
public void fetch() {
System.out.println(name + " is fetching");
}
}
// Usage
Dog dog = new Dog("Buddy", "Golden Retriever");
dog.speak(); // "Buddy barks"
dog.fetch(); // "Buddy is fetching"
Interfaces
// Interface definition
public interface Drawable {
void draw(); // Abstract method
// Default method (Java 8+)
default void display() {
System.out.println("Displaying...");
}
// Static method (Java 8+)
static void info() {
System.out.println("Drawable interface");
}
}
// Implementation
public class Circle implements Drawable {
private double radius;
public Circle(double radius) {
this.radius = radius;
}
@Override
public void draw() {
System.out.println("Drawing circle with radius " + radius);
}
}
// Multiple interfaces
public class Square implements Drawable, Comparable<Square> {
private double side;
public Square(double side) {
this.side = side;
}
@Override
public void draw() {
System.out.println("Drawing square with side " + side);
}
@Override
public int compareTo(Square other) {
return Double.compare(this.side, other.side);
}
}
Abstract Classes
public abstract class Shape {
protected String color;
public Shape(String color) {
this.color = color;
}
// Abstract method
public abstract double area();
// Concrete method
public void setColor(String color) {
this.color = color;
}
public String getColor() {
return color;
}
}
public class Rectangle extends Shape {
private double width;
private double height;
public Rectangle(String color, double width, double height) {
super(color);
this.width = width;
this.height = height;
}
@Override
public double area() {
return width * height;
}
}
Exception Handling
// Try-catch
try {
int result = 10 / 0;
} catch (ArithmeticException e) {
System.out.println("Cannot divide by zero!");
}
// Multiple catch blocks
try {
int[] arr = new int[5];
arr[10] = 50;
} catch (ArrayIndexOutOfBoundsException e) {
System.out.println("Array index out of bounds");
} catch (Exception e) {
System.out.println("General exception: " + e.getMessage());
}
// Finally block
try {
// Code that may throw exception
} catch (Exception e) {
e.printStackTrace();
} finally {
System.out.println("This always executes");
}
// Try-with-resources (Java 7+)
try (BufferedReader br = new BufferedReader(new FileReader("file.txt"))) {
String line = br.readLine();
} catch (IOException e) {
e.printStackTrace();
}
// Throwing exceptions
public void checkAge(int age) throws IllegalArgumentException {
if (age < 0) {
throw new IllegalArgumentException("Age cannot be negative");
}
}
// Custom exception
public class InvalidAgeException extends Exception {
public InvalidAgeException(String message) {
super(message);
}
}
Streams and Lambdas (Java 8+)
Lambda Expressions
// Functional interface
@FunctionalInterface
interface Calculator {
int calculate(int a, int b);
}
// Lambda expression
Calculator add = (a, b) -> a + b;
Calculator multiply = (a, b) -> a * b;
System.out.println(add.calculate(5, 3)); // 8
System.out.println(multiply.calculate(5, 3)); // 15
// With collections
List<String> names = Arrays.asList("Alice", "Bob", "Charlie");
names.forEach(name -> System.out.println(name));
// Method reference
names.forEach(System.out::println);
Streams
import java.util.stream.*;
List<Integer> numbers = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
// Filter
List<Integer> evens = numbers.stream()
.filter(n -> n % 2 == 0)
.collect(Collectors.toList());
// Map
List<Integer> squared = numbers.stream()
.map(n -> n * n)
.collect(Collectors.toList());
// Reduce
int sum = numbers.stream()
.reduce(0, (a, b) -> a + b);
// Find
Optional<Integer> first = numbers.stream()
.filter(n -> n > 5)
.findFirst();
// Any/All match
boolean anyEven = numbers.stream().anyMatch(n -> n % 2 == 0);
boolean allPositive = numbers.stream().allMatch(n -> n > 0);
// Sorted
List<Integer> sorted = numbers.stream()
.sorted()
.collect(Collectors.toList());
// Limit and skip
List<Integer> limited = numbers.stream()
.limit(5)
.collect(Collectors.toList());
// Chaining operations
List<String> result = Arrays.asList("apple", "banana", "cherry", "date")
.stream()
.filter(s -> s.length() > 5)
.map(String::toUpperCase)
.sorted()
.collect(Collectors.toList());
Common Patterns
Singleton
public class Singleton {
private static Singleton instance;
private Singleton() {
// Private constructor
}
public static Singleton getInstance() {
if (instance == null) {
instance = new Singleton();
}
return instance;
}
}
// Thread-safe singleton
public class ThreadSafeSingleton {
private static volatile ThreadSafeSingleton instance;
private ThreadSafeSingleton() {}
public static ThreadSafeSingleton getInstance() {
if (instance == null) {
synchronized (ThreadSafeSingleton.class) {
if (instance == null) {
instance = new ThreadSafeSingleton();
}
}
}
return instance;
}
}
Factory Pattern
interface Animal {
void speak();
}
class Dog implements Animal {
public void speak() {
System.out.println("Woof!");
}
}
class Cat implements Animal {
public void speak() {
System.out.println("Meow!");
}
}
class AnimalFactory {
public static Animal createAnimal(String type) {
if (type.equals("dog")) {
return new Dog();
} else if (type.equals("cat")) {
return new Cat();
}
throw new IllegalArgumentException("Unknown animal type");
}
}
// Usage
Animal animal = AnimalFactory.createAnimal("dog");
animal.speak();
Builder Pattern
public class User {
private final String firstName;
private final String lastName;
private final int age;
private final String email;
private User(UserBuilder builder) {
this.firstName = builder.firstName;
this.lastName = builder.lastName;
this.age = builder.age;
this.email = builder.email;
}
public static class UserBuilder {
private String firstName;
private String lastName;
private int age;
private String email;
public UserBuilder(String firstName, String lastName) {
this.firstName = firstName;
this.lastName = lastName;
}
public UserBuilder age(int age) {
this.age = age;
return this;
}
public UserBuilder email(String email) {
this.email = email;
return this;
}
public User build() {
return new User(this);
}
}
}
// Usage
User user = new User.UserBuilder("Alice", "Smith")
.age(30)
.email("alice@example.com")
.build();
File I/O
import java.io.*;
import java.nio.file.*;
// Reading file
try (BufferedReader br = new BufferedReader(new FileReader("file.txt"))) {
String line;
while ((line = br.readLine()) != null) {
System.out.println(line);
}
} catch (IOException e) {
e.printStackTrace();
}
// Writing file
try (BufferedWriter bw = new BufferedWriter(new FileWriter("file.txt"))) {
bw.write("Hello, World!");
bw.newLine();
bw.write("Second line");
} catch (IOException e) {
e.printStackTrace();
}
// Using Files class (Java 7+)
try {
// Read all lines
List<String> lines = Files.readAllLines(Paths.get("file.txt"));
// Write lines
Files.write(Paths.get("output.txt"),
Arrays.asList("Line 1", "Line 2"));
// Copy file
Files.copy(Paths.get("source.txt"), Paths.get("dest.txt"));
// Delete file
Files.delete(Paths.get("file.txt"));
} catch (IOException e) {
e.printStackTrace();
}
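The Files API also pairs with streams (Java 8+). A small sketch that writes a temp file and streams it back, so no fixed file names are assumed:

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

class FilesStreamDemo {
    // Read a file lazily with Files.lines and upper-case each line.
    // try-with-resources closes the stream and releases the file handle.
    static List<String> readUpper(Path p) throws IOException {
        try (Stream<String> lines = Files.lines(p)) {
            return lines.map(String::toUpperCase).collect(Collectors.toList());
        }
    }

    public static void main(String[] args) throws IOException {
        Path tmp = Files.createTempFile("notes", ".txt");
        Files.write(tmp, List.of("alpha", "beta"));
        System.out.println(readUpper(tmp)); // [ALPHA, BETA]
        Files.delete(tmp);
    }
}
```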
Best Practices
- Follow naming conventions
  - Classes: PascalCase (MyClass)
  - Methods/variables: camelCase (myMethod)
  - Constants: UPPER_SNAKE_CASE (MAX_SIZE)
- Use meaningful names
  // Good
  int studentCount = 50;
  // Bad
  int sc = 50;
- Keep methods small - One responsibility per method
- Use StringBuilder for string concatenation in loops
- Close resources - Use try-with-resources
- Handle exceptions properly - Don't swallow exceptions
- Use generics for type safety
- Follow SOLID principles
- Use Optional to avoid null checks (Java 8+)
  Optional<String> optional = Optional.ofNullable(getValue());
  String value = optional.orElse("default");
- Use streams for collection processing (Java 8+)
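Optional and streams combine naturally, since terminal operations like findFirst already return an Optional. A small sketch with made-up data:

```java
import java.util.List;

class OptionalStreamDemo {
    // Find the first name longer than 4 chars, or fall back to a default.
    static String firstLongName(List<String> names) {
        return names.stream()
                .filter(n -> n.length() > 4)
                .findFirst()       // Optional<String>
                .orElse("nobody"); // default value instead of null checks
    }

    public static void main(String[] args) {
        System.out.println(firstLongName(List.of("Bob", "Alice"))); // Alice
        System.out.println(firstLongName(List.of("Bob")));          // nobody
    }
}
```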
Common Libraries/Frameworks
- Spring Boot: Application framework
- Hibernate: ORM framework
- JUnit: Testing framework
- Maven/Gradle: Build tools
- Jackson: JSON processing
- Log4j/SLF4J: Logging
- Apache Commons: Utility libraries
Go Programming
Overview
Go (Golang) is a statically typed, compiled programming language designed at Google. It's known for its simplicity, efficiency, and excellent support for concurrent programming.
Key Features:
- Fast compilation and execution
- Built-in concurrency (goroutines and channels)
- Garbage collection
- Strong static typing with type inference
- Simple and clean syntax
- Excellent standard library
- Cross-platform compilation
Basic Syntax
Variables and Data Types
package main
import "fmt"
func main() {
// Variable declaration
var name string = "Alice"
var age int = 30
// Short declaration (type inference)
city := "NYC"
isActive := true
// Multiple declarations
var x, y, z int = 1, 2, 3
a, b := 10, 20
// Constants
const PI = 3.14159
const MaxSize = 100
// Zero values (default values)
var num int // 0
var str string // ""
var flag bool // false
var ptr *int // nil
fmt.Println(name, age, city, isActive)
}
Data Types
// Basic types
var i int = 42 // Platform-dependent (32 or 64 bit)
var i8 int8 = 127 // 8-bit
var i16 int16 = 32767 // 16-bit
var i32 int32 = 2147483647 // 32-bit (rune alias)
var i64 int64 = 9223372036854775807 // 64-bit
var u uint = 42 // Unsigned, platform-dependent
var u8 uint8 = 255 // 8-bit (byte alias)
var f32 float32 = 3.14 // 32-bit float
var f64 float64 = 3.14159 // 64-bit float
var c64 complex64 = 1 + 2i
var c128 complex128 = 1 + 2i
var b bool = true
var r rune = 'A' // Unicode code point (int32)
var by byte = 65 // Alias for uint8
var str string = "Hello, 世界"
// Type conversion
var x int = 42
var y float64 = float64(x)
var z uint = uint(x)
Strings
// String operations
s1 := "Hello"
s2 := "World"
// Concatenation
full := s1 + " " + s2
// Length (bytes, not runes)
length := len(s1)
// Accessing bytes
firstByte := s1[0]
// Substrings
sub := s1[1:4] // "ell"
// String comparison
if s1 == s2 {
fmt.Println("Equal")
}
// Multi-line strings
multiline := `This is a
multi-line
string`
// String iteration
for i, char := range "Hello" {
fmt.Printf("%d: %c\n", i, char)
}
// String formatting
import "fmt"
formatted := fmt.Sprintf("Name: %s, Age: %d", "Alice", 30)
// String conversion
import "strconv"
numStr := strconv.Itoa(42) // int to string
num, err := strconv.Atoi("42") // string to int
Arrays and Slices
Arrays
// Fixed-size arrays
var arr [5]int
arr[0] = 1
// Array literal
numbers := [5]int{1, 2, 3, 4, 5}
// Compiler counts length
auto := [...]int{1, 2, 3, 4}
// Multi-dimensional arrays
matrix := [3][3]int{
{1, 2, 3},
{4, 5, 6},
{7, 8, 9},
}
// Array length
length := len(numbers)
// Iterate over array
for i, v := range numbers {
fmt.Printf("%d: %d\n", i, v)
}
Slices (Dynamic Arrays)
// Creating slices
var slice []int // nil slice
slice = []int{1, 2, 3, 4, 5} // slice literal
slice = make([]int, 5) // length 5, all zeros
slice = make([]int, 5, 10) // length 5, capacity 10
// Append to slice
slice = append(slice, 6)
slice = append(slice, 7, 8, 9)
// Slice operations
arr := []int{1, 2, 3, 4, 5}
sub := arr[1:4] // [2, 3, 4]
first := arr[:3] // [1, 2, 3]
last := arr[3:] // [4, 5]
// Length and capacity
length := len(slice) // avoid naming variables len/cap: it shadows the built-ins
capacity := cap(slice)
// Copy slices
src := []int{1, 2, 3}
dst := make([]int, len(src))
copy(dst, src)
// 2D slices
matrix := [][]int{
{1, 2, 3},
{4, 5, 6},
}
// Iterate
for i, v := range slice {
fmt.Printf("%d: %d\n", i, v)
}
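One subtlety worth noting: a sub-slice shares its parent's backing array until an append exceeds capacity and forces a reallocation. A minimal sketch (the values are arbitrary):

```go
package main

import "fmt"

func main() {
	base := []int{1, 2, 3, 4, 5}
	sub := base[1:3] // len 2, cap 4; shares base's backing array

	sub[0] = 99          // writes through to the shared array
	fmt.Println(base[1]) // 99

	// Appending beyond capacity copies to a new array;
	// after that, sub and base are independent.
	sub = append(sub, 6, 7, 8, 9)
	sub[0] = 0
	fmt.Println(base[1]) // still 99
}
```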
Maps
// Creating maps
var m map[string]int // nil map
m = make(map[string]int) // empty map
m = map[string]int{ // map literal
"Alice": 25,
"Bob": 30,
}
// Adding/updating elements
m["Charlie"] = 35
m["Alice"] = 26
// Accessing elements
age := m["Alice"]
// Check if key exists
age, ok := m["Alice"]
if ok {
fmt.Println("Alice's age:", age)
}
// Delete element
delete(m, "Bob")
// Iterate over map
for key, value := range m {
fmt.Printf("%s: %d\n", key, value)
}
// Map length
size := len(m)
// Nested maps
nested := map[string]map[string]int{
"group1": {
"Alice": 25,
"Bob": 30,
},
"group2": {
"Charlie": 35,
},
}
Control Flow
If-Else
age := 18
if age < 13 {
fmt.Println("Child")
} else if age < 20 {
fmt.Println("Teenager")
} else {
fmt.Println("Adult")
}
// If with initialization
if num := 42; num > 0 {
fmt.Println("Positive")
}
// Error checking pattern
if err := someFunction(); err != nil {
fmt.Println("Error:", err)
}
Switch
// Basic switch
day := 3
switch day {
case 1:
fmt.Println("Monday")
case 2:
fmt.Println("Tuesday")
case 3:
fmt.Println("Wednesday")
default:
fmt.Println("Other day")
}
// Multiple cases
switch day {
case 1, 2, 3, 4, 5:
fmt.Println("Weekday")
case 6, 7:
fmt.Println("Weekend")
}
// Switch with condition
num := 42
switch {
case num < 0:
fmt.Println("Negative")
case num == 0:
fmt.Println("Zero")
case num > 0:
fmt.Println("Positive")
}
// Type switch
var i interface{} = "hello"
switch v := i.(type) {
case string:
fmt.Println("String:", v)
case int:
fmt.Println("Int:", v)
default:
fmt.Println("Unknown type")
}
Loops
// For loop (only loop in Go)
for i := 0; i < 5; i++ {
fmt.Println(i)
}
// While-style loop
count := 0
for count < 5 {
fmt.Println(count)
count++
}
// Infinite loop
for {
fmt.Println("Forever")
break // Exit loop
}
// Range over slice
numbers := []int{1, 2, 3, 4, 5}
for i, v := range numbers {
fmt.Printf("%d: %d\n", i, v)
}
// Range over map
m := map[string]int{"a": 1, "b": 2}
for key, value := range m {
fmt.Printf("%s: %d\n", key, value)
}
// Ignore index/value with _
for _, v := range numbers {
fmt.Println(v)
}
// Break and continue
for i := 0; i < 10; i++ {
if i == 5 {
continue
}
if i == 8 {
break
}
fmt.Println(i)
}
Functions
Basic Functions
// Simple function
func greet(name string) {
fmt.Println("Hello,", name)
}
// Function with return value
func add(a, b int) int {
return a + b
}
// Multiple parameters of same type
func multiply(a, b, c int) int {
return a * b * c
}
// Multiple return values
func swap(a, b string) (string, string) {
return b, a
}
// Named return values
func divide(a, b float64) (result float64, err error) {
if b == 0 {
err = fmt.Errorf("division by zero")
return
}
result = a / b
return // Naked return
}
// Variadic functions
func sum(numbers ...int) int {
total := 0
for _, num := range numbers {
total += num
}
return total
}
// Usage
result := sum(1, 2, 3, 4, 5)
Anonymous Functions and Closures
// Anonymous function
add := func(a, b int) int {
return a + b
}
result := add(5, 3)
// Immediately invoked function
result := func(a, b int) int {
return a + b
}(5, 3)
// Closure
func counter() func() int {
count := 0
return func() int {
count++
return count
}
}
c := counter()
fmt.Println(c()) // 1
fmt.Println(c()) // 2
fmt.Println(c()) // 3
Defer
// Defer executes at function end
func example() {
defer fmt.Println("World")
fmt.Println("Hello")
}
// Output: Hello
// World
// Multiple defers (LIFO order)
func multiDefer() {
defer fmt.Println("1")
defer fmt.Println("2")
defer fmt.Println("3")
}
// Output: 3, 2, 1
// Common pattern: cleanup
func readFile(filename string) error {
f, err := os.Open(filename)
if err != nil {
return err
}
defer f.Close()
// Work with file
return nil
}
Structs and Methods
Structs
// Define struct
type Person struct {
Name string
Age int
Email string
}
// Create struct
p1 := Person{
Name: "Alice",
Age: 30,
Email: "alice@example.com",
}
// Short form
p2 := Person{"Bob", 25, "bob@example.com"}
// Anonymous struct
person := struct {
name string
age int
}{
name: "Charlie",
age: 35,
}
// Accessing fields
fmt.Println(p1.Name)
p1.Age = 31
// Pointer to struct
p := &Person{Name: "Alice", Age: 30}
p.Age = 31 // Automatic dereferencing
// Embedded structs
type Address struct {
City string
Country string
}
type Employee struct {
Person // Embedded struct
Address // Embedded struct
Salary float64
}
emp := Employee{
Person: Person{Name: "Alice", Age: 30},
Address: Address{City: "NYC", Country: "USA"},
Salary: 100000,
}
// Access embedded fields
fmt.Println(emp.Name) // From Person
fmt.Println(emp.City) // From Address
Methods
// Method on struct
type Rectangle struct {
Width float64
Height float64
}
// Value receiver
func (r Rectangle) Area() float64 {
return r.Width * r.Height
}
// Pointer receiver (can modify)
func (r *Rectangle) Scale(factor float64) {
r.Width *= factor
r.Height *= factor
}
// Usage
rect := Rectangle{Width: 10, Height: 5}
area := rect.Area()
rect.Scale(2)
// Method on any type
type MyInt int
func (m MyInt) Double() MyInt {
return m * 2
}
num := MyInt(5)
result := num.Double() // 10
Interfaces
// Define interface
type Shape interface {
Area() float64
Perimeter() float64
}
// Implement interface (implicit)
type Circle struct {
Radius float64
}
func (c Circle) Area() float64 {
return 3.14159 * c.Radius * c.Radius
}
func (c Circle) Perimeter() float64 {
return 2 * 3.14159 * c.Radius
}
type Rectangle struct {
Width, Height float64
}
func (r Rectangle) Area() float64 {
return r.Width * r.Height
}
func (r Rectangle) Perimeter() float64 {
return 2 * (r.Width + r.Height)
}
// Use interface
func printArea(s Shape) {
fmt.Printf("Area: %.2f\n", s.Area())
}
// Usage
c := Circle{Radius: 5}
r := Rectangle{Width: 10, Height: 5}
printArea(c)
printArea(r)
// Empty interface (any type)
func printAnything(v interface{}) {
fmt.Println(v)
}
printAnything(42)
printAnything("hello")
printAnything(true)
// Type assertion
var i interface{} = "hello"
s := i.(string)
s, ok := i.(string) // Safe type assertion
// Type switch
switch v := i.(type) {
case string:
fmt.Println("String:", v)
case int:
fmt.Println("Int:", v)
default:
fmt.Println("Unknown")
}
Concurrency
Goroutines
// Start goroutine
go func() {
fmt.Println("Hello from goroutine")
}()
// Multiple goroutines
for i := 0; i < 5; i++ {
go func(n int) {
fmt.Println("Goroutine", n)
}(i)
}
// Wait for goroutines
import "sync"
var wg sync.WaitGroup
for i := 0; i < 5; i++ {
wg.Add(1)
go func(n int) {
defer wg.Done()
fmt.Println("Worker", n)
}(i)
}
wg.Wait()
Channels
// Create channel
ch := make(chan int)
// Buffered channel
ch := make(chan int, 5)
// Send to channel
go func() {
ch <- 42
}()
// Receive from channel
value := <-ch
// Close channel
close(ch)
// Range over channel
go func() {
for i := 0; i < 5; i++ {
ch <- i
}
close(ch)
}()
for value := range ch {
fmt.Println(value)
}
// Select statement
ch1 := make(chan string)
ch2 := make(chan string)
go func() {
ch1 <- "from ch1"
}()
go func() {
ch2 <- "from ch2"
}()
select {
case msg1 := <-ch1:
fmt.Println(msg1)
case msg2 := <-ch2:
fmt.Println(msg2)
case <-time.After(1 * time.Second):
fmt.Println("timeout")
}
Sync Package
import "sync"
// Mutex
var (
mu sync.Mutex
count int
)
func increment() {
mu.Lock()
defer mu.Unlock()
count++
}
// RWMutex (multiple readers, single writer)
var (
rwMu sync.RWMutex
data map[string]int
)
func read(key string) int {
rwMu.RLock()
defer rwMu.RUnlock()
return data[key]
}
func write(key string, value int) {
rwMu.Lock()
defer rwMu.Unlock()
data[key] = value
}
// Once (execute only once)
var once sync.Once
func initialize() {
once.Do(func() {
fmt.Println("Initialized")
})
}
Error Handling
import "errors"
import "fmt"
// Return error
func divide(a, b float64) (float64, error) {
if b == 0 {
return 0, errors.New("division by zero")
}
return a / b, nil
}
// Formatted error
func validateAge(age int) error {
if age < 0 {
return fmt.Errorf("invalid age: %d", age)
}
return nil
}
// Custom error type
type ValidationError struct {
Field string
Value interface{}
}
func (e *ValidationError) Error() string {
return fmt.Sprintf("validation error: %s = %v", e.Field, e.Value)
}
// Error handling pattern
result, err := divide(10, 0)
if err != nil {
fmt.Println("Error:", err)
return
}
fmt.Println("Result:", result)
// Panic and recover
func riskyOperation() {
defer func() {
if r := recover(); r != nil {
fmt.Println("Recovered from panic:", r)
}
}()
panic("something went wrong")
}
Packages and Imports
// Package declaration
package main
// Import single package
import "fmt"
// Import multiple packages
import (
"fmt"
"math"
"strings"
)
// Aliased import
import f "fmt"
f.Println("Hello")
// Blank import (side effects)
import _ "database/sql/driver"
// Creating a package
// mypackage/mypackage.go
package mypackage
// Exported (capitalized)
func PublicFunction() {
fmt.Println("Public")
}
// Not exported (lowercase)
func privateFunction() {
fmt.Println("Private")
}
// Using the package
import "myproject/mypackage"
mypackage.PublicFunction()
File I/O
import (
    "bufio"
    "fmt"
    "os"
)
// Read entire file (ioutil.ReadFile is deprecated since Go 1.16; use os.ReadFile)
data, err := os.ReadFile("file.txt")
if err != nil {
    panic(err)
}
fmt.Println(string(data))
// Write file
err = os.WriteFile("output.txt", []byte("Hello"), 0644)
// Open file
file, err := os.Open("file.txt")
if err != nil {
panic(err)
}
defer file.Close()
// Read line by line
scanner := bufio.NewScanner(file)
for scanner.Scan() {
fmt.Println(scanner.Text())
}
// Write to file
file, err := os.Create("output.txt")
if err != nil {
panic(err)
}
defer file.Close()
writer := bufio.NewWriter(file)
writer.WriteString("Hello, World!\n")
writer.Flush()
Common Patterns
Singleton
import "sync"
type singleton struct {
data string
}
var (
instance *singleton
once sync.Once
)
func GetInstance() *singleton {
once.Do(func() {
instance = &singleton{data: "singleton"}
})
return instance
}
Factory Pattern
type Animal interface {
Speak() string
}
type Dog struct{}
func (d Dog) Speak() string { return "Woof!" }
type Cat struct{}
func (c Cat) Speak() string { return "Meow!" }
func NewAnimal(animalType string) Animal {
switch animalType {
case "dog":
return Dog{}
case "cat":
return Cat{}
default:
return nil
}
}
Builder Pattern
type User struct {
firstName string
lastName string
age int
email string
}
type UserBuilder struct {
user User
}
func NewUserBuilder() *UserBuilder {
return &UserBuilder{}
}
func (b *UserBuilder) FirstName(name string) *UserBuilder {
b.user.firstName = name
return b
}
func (b *UserBuilder) LastName(name string) *UserBuilder {
b.user.lastName = name
return b
}
func (b *UserBuilder) Age(age int) *UserBuilder {
b.user.age = age
return b
}
func (b *UserBuilder) Email(email string) *UserBuilder {
b.user.email = email
return b
}
func (b *UserBuilder) Build() User {
return b.user
}
// Usage
user := NewUserBuilder().
FirstName("Alice").
LastName("Smith").
Age(30).
Email("alice@example.com").
Build()
Testing
// main.go
package main
func Add(a, b int) int {
return a + b
}
// main_test.go
package main
import "testing"
func TestAdd(t *testing.T) {
result := Add(2, 3)
expected := 5
if result != expected {
t.Errorf("Add(2, 3) = %d; want %d", result, expected)
}
}
func TestAddNegative(t *testing.T) {
result := Add(-1, -1)
expected := -2
if result != expected {
t.Errorf("Add(-1, -1) = %d; want %d", result, expected)
}
}
// Table-driven tests
func TestAddTable(t *testing.T) {
tests := []struct {
a, b, expected int
}{
{1, 2, 3},
{0, 0, 0},
{-1, 1, 0},
{10, 20, 30},
}
for _, tt := range tests {
result := Add(tt.a, tt.b)
if result != tt.expected {
t.Errorf("Add(%d, %d) = %d; want %d",
tt.a, tt.b, result, tt.expected)
}
}
}
// Run tests: go test
// Run with coverage: go test -cover
Best Practices
- Use gofmt - Format code automatically
  gofmt -w .
- Use a linter - golint is now deprecated; staticcheck is a common replacement
  staticcheck ./...
- Error handling - Always check errors
  if err != nil {
      return err
  }
- Use interfaces - Program to interfaces, not implementations
- Prefer composition over inheritance
- Keep functions small - Single responsibility
- Use meaningful names - Clear and descriptive
- Document exported items - Comments for public API
  // Add returns the sum of two integers.
  func Add(a, b int) int {
      return a + b
  }
- Use context for cancellation and timeouts
  ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  defer cancel()
- Avoid global state - Pass dependencies explicitly
Common Libraries
- gorilla/mux: HTTP router
- gin: Web framework
- gorm: ORM
- viper: Configuration
- cobra: CLI applications
- logrus: Logging
- testify: Testing toolkit
- zap: Fast logging
- grpc: RPC framework
- redis: Redis client
Go Modules
# Initialize module
go mod init github.com/username/project
# Add dependency
go get github.com/gin-gonic/gin
# Update dependencies
go get -u
# Tidy dependencies
go mod tidy
# Vendor dependencies
go mod vendor
Useful Commands
# Run program
go run main.go
# Build executable
go build
# Install binary
go install
# Format code
go fmt ./...
# Run tests
go test ./...
# Get dependencies
go get package
# Show documentation
go doc fmt.Println
Lua Programming
Overview
Lua is a lightweight, high-level, multi-paradigm programming language designed primarily for embedded use in applications. It's known for its simplicity, efficiency, and powerful data description constructs.
Key Features:
- Lightweight and embeddable
- Fast execution
- Simple and clean syntax
- Dynamic typing
- Automatic memory management (garbage collection)
- First-class functions
- Powerful table data structure
- Coroutines for concurrency
Common Uses:
- Game scripting (World of Warcraft, Roblox)
- Embedded systems
- Configuration files
- Application scripting
- Web development (OpenResty)
Basic Syntax
Variables and Data Types
-- Variables (global by default)
name = "Alice"
age = 30
pi = 3.14159
-- Local variables (recommended)
local x = 10
local y = 20
-- Multiple assignment
local a, b, c = 1, 2, 3
-- Swap variables
a, b = b, a
-- Nil (undefined/null)
local nothing = nil
-- Comments
-- Single line comment
--[[
Multi-line
comment
]]
-- Data types
local num = 42 -- number
local str = "Hello" -- string
local bool = true -- boolean
local tbl = {1, 2, 3} -- table
local func = function() end -- function
local thread = coroutine.create(function() end) -- thread
local nothing = nil -- nil
-- Type checking
print(type(num)) -- number
print(type(str)) -- string
print(type(bool)) -- boolean
print(type(tbl)) -- table
print(type(func)) -- function
Strings
-- String creation
local s1 = "Hello"
local s2 = 'World'
local s3 = [[Multi-line
string]]
-- String concatenation
local full = s1 .. " " .. s2
-- String length
local len = #s1
local len2 = string.len(s1)
-- String methods
local upper = string.upper(s1) -- "HELLO"
local lower = string.lower(s1) -- "hello"
local sub = string.sub(s1, 1, 3) -- "Hel"
local find = string.find(s1, "ll") -- 3, 4
local replace = string.gsub(s1, "l", "L") -- "HeLLo"
-- String formatting
local formatted = string.format("Name: %s, Age: %d", "Alice", 30)
-- String to number
local num = tonumber("42")
local str = tostring(42)
-- String repetition
local repeated = string.rep("Ha", 3) -- "HaHaHa"
-- Pattern matching (similar to regex)
local match = string.match("Hello123", "%d+") -- "123"
-- Iterate characters
for i = 1, #s1 do
local char = string.sub(s1, i, i)
print(char)
end
Tables
Tables are the only data structure in Lua - they can represent arrays, dictionaries, objects, and more.
Arrays (1-indexed)
-- Array creation
local arr = {10, 20, 30, 40, 50}
-- Accessing elements (1-indexed!)
print(arr[1]) -- 10
-- Modifying elements
arr[1] = 15
-- Array length
local len = #arr
-- Append to array
table.insert(arr, 60) -- Append to end
table.insert(arr, 2, 25) -- Insert at position 2
-- Remove from array
local last = table.remove(arr) -- Remove last
local second = table.remove(arr, 2) -- Remove at position 2
-- Iterate array
for i = 1, #arr do
print(i, arr[i])
end
-- Iterate with ipairs
for i, v in ipairs(arr) do
print(i, v)
end
-- Table functions
table.sort(arr) -- Sort ascending
table.sort(arr, function(a, b) return a > b end) -- Sort descending
local str = table.concat(arr, ", ") -- Join with separator
Dictionaries/Maps
-- Dictionary creation
local person = {
name = "Alice",
age = 30,
city = "NYC"
}
-- Alternative syntax
local person2 = {
["name"] = "Bob",
["age"] = 25
}
-- Accessing elements
print(person.name) -- Dot notation
print(person["age"]) -- Bracket notation
-- Adding/modifying
person.email = "alice@example.com"
person["phone"] = "123-456-7890"
-- Removing
person.email = nil
-- Iterate dictionary
for key, value in pairs(person) do
print(key, value)
end
-- Check if key exists
if person.name then
print("Name exists")
end
-- Nested tables
local nested = {
user = {
name = "Alice",
address = {
city = "NYC",
country = "USA"
}
}
}
print(nested.user.address.city)
Mixed Tables
-- Table with both array and dictionary parts
local mixed = {
"first", -- [1] = "first"
"second", -- [2] = "second"
name = "Alice",
age = 30
}
print(mixed[1]) -- "first"
print(mixed.name) -- "Alice"
-- Length only counts array part
print(#mixed) -- 2
Control Flow
If-Else
local age = 18
if age < 13 then
print("Child")
elseif age < 20 then
print("Teenager")
else
print("Adult")
end
-- Logical operators: and, or, not
if age >= 18 and age < 65 then
print("Working age")
end
if age < 18 or age > 65 then
print("Not working age")
end
if not (age < 18) then
print("Adult")
end
-- Ternary-like operator
local status = age >= 18 and "Adult" or "Minor"
Loops
-- While loop
local count = 0
while count < 5 do
print(count)
count = count + 1
end
-- Repeat-until loop (do-while)
local i = 0
repeat
print(i)
i = i + 1
until i >= 5
-- For loop (numeric)
for i = 1, 5 do
print(i) -- 1, 2, 3, 4, 5
end
-- For loop with step
for i = 0, 10, 2 do
print(i) -- 0, 2, 4, 6, 8, 10
end
-- For loop (reverse)
for i = 5, 1, -1 do
print(i) -- 5, 4, 3, 2, 1
end
-- Iterate array with ipairs
local arr = {10, 20, 30, 40, 50}
for i, v in ipairs(arr) do
print(i, v)
end
-- Iterate table with pairs
local person = {name = "Alice", age = 30}
for key, value in pairs(person) do
print(key, value)
end
-- Break
for i = 1, 10 do
if i == 5 then
break
end
print(i)
end
-- No continue in Lua (use goto in Lua 5.2+)
for i = 1, 10 do
if i == 5 then
goto continue
end
print(i)
::continue::
end
Functions
-- Basic function
function greet(name)
print("Hello, " .. name)
end
greet("Alice")
-- Function with return value
function add(a, b)
return a + b
end
local result = add(5, 3)
-- Multiple return values
function swap(a, b)
return b, a
end
local x, y = swap(10, 20)
-- Default parameters (manual)
function greet(name)
name = name or "World"
print("Hello, " .. name)
end
-- Variable arguments
function sum(...)
local total = 0
for _, v in ipairs({...}) do
total = total + v
end
return total
end
print(sum(1, 2, 3, 4, 5)) -- 15
-- Anonymous functions
local add = function(a, b)
return a + b
end
-- Function as argument
function applyOperation(a, b, operation)
return operation(a, b)
end
local result = applyOperation(5, 3, function(x, y)
return x * y
end)
-- Closures
function counter()
local count = 0
return function()
count = count + 1
return count
end
end
local c = counter()
print(c()) -- 1
print(c()) -- 2
print(c()) -- 3
-- Local functions
local function helper()
print("Helper function")
end
-- Recursive functions need forward declaration
local factorial
factorial = function(n)
if n <= 1 then
return 1
else
return n * factorial(n - 1)
end
end
Object-Oriented Programming
Lua doesn't have built-in classes, but tables and metatables provide OOP features.
Tables as Objects
-- Simple object
local person = {
name = "Alice",
age = 30,
greet = function(self)
print("Hello, I'm " .. self.name)
end
}
person:greet() -- Colon syntax passes self automatically
-- Equivalent to: person.greet(person)
Metatables and Classes
-- Define a class
local Person = {}
Person.__index = Person
-- Constructor
function Person:new(name, age)
local instance = setmetatable({}, Person)
instance.name = name
instance.age = age
return instance
end
-- Methods
function Person:greet()
print("Hello, I'm " .. self.name)
end
function Person:getAge()
return self.age
end
function Person:setAge(age)
self.age = age
end
-- Usage
local alice = Person:new("Alice", 30)
alice:greet()
print(alice:getAge())
-- Inheritance
local Employee = setmetatable({}, {__index = Person})
Employee.__index = Employee
function Employee:new(name, age, salary)
local instance = Person:new(name, age)
setmetatable(instance, Employee)
instance.salary = salary
return instance
end
function Employee:getSalary()
return self.salary
end
-- Usage
local emp = Employee:new("Bob", 25, 50000)
emp:greet() -- Inherited from Person
print(emp:getSalary())
-- Operator overloading
local Vector = {}
Vector.__index = Vector
function Vector:new(x, y)
return setmetatable({x = x, y = y}, Vector)
end
-- Overload addition
Vector.__add = function(a, b)
return Vector:new(a.x + b.x, a.y + b.y)
end
-- Overload tostring
Vector.__tostring = function(v)
return "(" .. v.x .. ", " .. v.y .. ")"
end
local v1 = Vector:new(1, 2)
local v2 = Vector:new(3, 4)
local v3 = v1 + v2
print(v3) -- (4, 6)
Modules
-- mymodule.lua
local M = {}
-- Private function
local function private()
print("Private")
end
-- Public function
function M.public()
print("Public")
end
function M.add(a, b)
return a + b
end
return M
-- main.lua
local mymodule = require("mymodule")
mymodule.public()
local result = mymodule.add(5, 3)
Error Handling
-- pcall (protected call)
local success, result = pcall(function()
return 10 / 0
end)
if success then
print("Result:", result)
else
print("Error:", result)
end
-- Error with message
function divide(a, b)
if b == 0 then
error("Division by zero")
end
return a / b
end
local success, result = pcall(divide, 10, 0)
if not success then
print("Error:", result)
end
-- Assert
local function checkPositive(n)
assert(n > 0, "Number must be positive")
return n
end
-- xpcall (with error handler)
local function errorHandler(err)
print("Error occurred:", err)
return err
end
local success, result = xpcall(function()
error("Something went wrong")
end, errorHandler)
File I/O
-- Read entire file
local file = io.open("input.txt", "r")
if file then
local content = file:read("*all")
print(content)
file:close()
end
-- Read line by line
local file = io.open("input.txt", "r")
if file then
for line in file:lines() do
print(line)
end
file:close()
end
-- Write file
local file = io.open("output.txt", "w")
if file then
file:write("Hello, World!\n")
file:write("Second line\n")
file:close()
end
-- Append to file
local file = io.open("output.txt", "a")
if file then
file:write("Appended line\n")
file:close()
end
-- Using io.input and io.output (the default files)
io.input("input.txt")
local content = io.read("*all")
io.input():close() -- close the current default input file
io.output("output.txt")
io.write("Hello\n")
io.close() -- with no argument, closes the default output file
Coroutines
-- Create coroutine
local co = coroutine.create(function()
for i = 1, 5 do
print("Coroutine:", i)
coroutine.yield() -- Pause execution
end
end)
-- Resume coroutine
coroutine.resume(co) -- Prints 1
coroutine.resume(co) -- Prints 2
coroutine.resume(co) -- Prints 3
-- Check status
print(coroutine.status(co)) -- suspended or running or dead
-- Producer-consumer pattern
local function producer()
return coroutine.create(function()
for i = 1, 5 do
coroutine.yield(i)
end
end)
end
local function consumer(prod)
while true do
local status, value = coroutine.resume(prod)
if not status then break end
print("Received:", value)
end
end
consumer(producer())
Common Patterns
Singleton Pattern
local Singleton = {}
local instance
function Singleton:getInstance()
if not instance then
instance = {data = "singleton"}
end
return instance
end
local s1 = Singleton:getInstance()
local s2 = Singleton:getInstance()
print(s1 == s2) -- true
Factory Pattern
local AnimalFactory = {}
function AnimalFactory:create(animalType)
if animalType == "dog" then
return {speak = function() return "Woof!" end}
elseif animalType == "cat" then
return {speak = function() return "Meow!" end}
end
end
local dog = AnimalFactory:create("dog")
print(dog.speak())
Observer Pattern
local Subject = {}
Subject.__index = Subject
function Subject:new()
return setmetatable({observers = {}}, Subject)
end
function Subject:attach(observer)
table.insert(self.observers, observer)
end
function Subject:detach(observer)
for i, obs in ipairs(self.observers) do
if obs == observer then
table.remove(self.observers, i)
break
end
end
end
function Subject:notify(data)
for _, observer in ipairs(self.observers) do
observer:update(data)
end
end
-- Observer
local Observer = {}
Observer.__index = Observer
function Observer:new(name)
return setmetatable({name = name}, Observer)
end
function Observer:update(data)
print(self.name .. " received: " .. data)
end
-- Usage
local subject = Subject:new()
local obs1 = Observer:new("Observer1")
local obs2 = Observer:new("Observer2")
subject:attach(obs1)
subject:attach(obs2)
subject:notify("Event occurred!")
Standard Library
-- Math
print(math.pi)
print(math.abs(-5))
print(math.floor(3.7))
print(math.ceil(3.2))
print(math.max(1, 5, 3))
print(math.min(1, 5, 3))
print(math.random()) -- Random [0, 1)
print(math.random(10)) -- Random [1, 10]
print(math.random(5, 10)) -- Random [5, 10]
-- String
print(string.upper("hello"))
print(string.lower("HELLO"))
print(string.reverse("hello"))
-- Table
local arr = {3, 1, 4, 1, 5}
table.sort(arr)
print(table.concat(arr, ", "))
-- OS
print(os.time())
print(os.date("%Y-%m-%d %H:%M:%S"))
os.execute("ls") -- Execute shell command
-- Pairs / IPairs
local t = {10, 20, 30, x = 1, y = 2}
for k, v in pairs(t) do -- All elements
print(k, v)
end
for i, v in ipairs(t) do -- Only array part
print(i, v)
end
Best Practices
- Use local variables - Faster and avoids global pollution
  local x = 10 -- Good
  x = 10 -- Bad (global)
- Prefer ipairs for arrays - More efficient than pairs
- Use metatables for OOP and operator overloading
- Always close files after use
- Use pcall for error handling in production
- Avoid goto - Use structured control flow
- Use string.format for complex string formatting
- Cache table lookups in loops
  local insert = table.insert
  for i = 1, 1000 do
      insert(arr, i)
  end
- Use semicolons sparingly - Optional in Lua
- Follow naming conventions
  - Variables: snake_case
  - Constants: UPPER_CASE
  - Functions: camelCase or snake_case
Common Use Cases
Configuration Files
-- config.lua
return {
database = {
host = "localhost",
port = 5432,
name = "mydb"
},
server = {
port = 8080,
workers = 4
}
}
-- Load config
local config = require("config")
print(config.database.host)
Game Scripting
-- Define enemy
local Enemy = {}
Enemy.__index = Enemy
function Enemy:new(name, health, damage)
return setmetatable({
name = name,
health = health,
damage = damage
}, Enemy)
end
function Enemy:attack(target)
target.health = target.health - self.damage
print(self.name .. " attacks " .. target.name)
end
function Enemy:isAlive()
return self.health > 0
end
Lua Versions
- Lua 5.1: Widely used in games (WoW, Roblox)
- Lua 5.2: Added goto, _ENV
- Lua 5.3: Integer subtype, bitwise operators
- Lua 5.4: To-be-closed variables, const variables
- LuaJIT: JIT compiler, very fast (used in OpenResty)
Useful Libraries
- LuaSocket: Networking
- LuaFileSystem: File system operations
- Penlight: Extended standard library
- LÖVE: Game framework
- OpenResty: Web platform (Nginx + Lua)
- LuaRocks: Package manager
Rust Programming
Overview
Rust is a systems programming language focused on safety, speed, and concurrency. It achieves memory safety without garbage collection through its ownership system.
Key Features:
- Memory safety without garbage collection
- Zero-cost abstractions
- Ownership and borrowing system
- Guaranteed thread safety
- Pattern matching
- Type inference
- Powerful macro system
- Excellent tooling (cargo, rustfmt, clippy)
Basic Syntax
Variables and Data Types
fn main() {
    // Immutable by default
    let x = 5;
    // x = 6; // Error! Cannot mutate immutable variable

    // Mutable variable
    let mut y = 5;
    y = 6; // OK

    // Constants (must have type annotation)
    const MAX_POINTS: u32 = 100_000;

    // Shadowing (redefining variable)
    let x = 5;
    let x = x + 1;
    let x = x * 2; // x is now 12

    // Type annotation
    let guess: u32 = "42".parse().expect("Not a number!");

    // Scalar types
    let integer: i32 = 42;
    let float: f64 = 3.14;
    let boolean: bool = true;
    let character: char = 'A';

    // Integer types: i8, i16, i32, i64, i128, isize
    // Unsigned: u8, u16, u32, u64, u128, usize
    let signed: i8 = -127;
    let unsigned: u8 = 255;

    // Number literals
    let decimal = 98_222;
    let hex = 0xff;
    let octal = 0o77;
    let binary = 0b1111_0000;
    let byte = b'A'; // u8 only
}
Strings
fn main() {
    // String slice (immutable, fixed size)
    let s1: &str = "Hello";

    // String (mutable, growable)
    let mut s2 = String::from("Hello");
    s2.push_str(", World!");

    // String operations
    let len = s2.len();
    let is_empty = s2.is_empty();
    let contains = s2.contains("World");

    // String concatenation
    let s3 = String::from("Hello");
    let s4 = String::from(" World");
    let s5 = s3 + &s4; // s3 is moved, can't use it anymore

    // Format macro (doesn't take ownership)
    let s6 = format!("{} {}", s1, s4);

    // String slicing
    let hello = &s2[0..5];
    let world = &s2[7..12];

    // Iterate over chars
    for c in "Hello".chars() {
        println!("{}", c);
    }

    // Iterate over bytes
    for b in "Hello".bytes() {
        println!("{}", b);
    }

    // String to number
    let num: i32 = "42".parse().unwrap();
}
Ownership and Borrowing
Ownership Rules
- Each value in Rust has a variable called its owner
- There can only be one owner at a time
- When the owner goes out of scope, the value is dropped
fn main() {
    // Move (ownership transfer)
    let s1 = String::from("hello");
    let s2 = s1; // s1 is no longer valid
    // println!("{}", s1); // Error!
    println!("{}", s2); // OK

    // Clone (deep copy)
    let s3 = String::from("hello");
    let s4 = s3.clone();
    println!("{} {}", s3, s4); // Both valid

    // Copy trait (stack-only data)
    let x = 5;
    let y = x; // x is still valid (Copy trait)
    println!("{} {}", x, y);
}
References and Borrowing
fn main() {
    // Immutable reference (borrowing)
    let s1 = String::from("hello");
    let len = calculate_length(&s1); // Borrow
    println!("{} has length {}", s1, len); // s1 still valid

    // Mutable reference
    let mut s = String::from("hello");
    change(&mut s);
    println!("{}", s); // "hello, world"

    // Rules:
    // 1. Multiple immutable references OR one mutable reference
    // 2. References must always be valid
    let r1 = &s; // OK
    let r2 = &s; // OK
    // let r3 = &mut s; // Error! Can't have mutable while immutable exists
    println!("{} {}", r1, r2); // r1 and r2 no longer used after this
    let r3 = &mut s; // OK now
}

fn calculate_length(s: &String) -> usize {
    s.len()
}

fn change(s: &mut String) {
    s.push_str(", world");
}
Lifetimes
// Lifetime annotations
fn longest<'a>(x: &'a str, y: &'a str) -> &'a str {
    if x.len() > y.len() { x } else { y }
}

// Struct with lifetime
struct ImportantExcerpt<'a> {
    part: &'a str,
}

impl<'a> ImportantExcerpt<'a> {
    fn level(&self) -> i32 {
        3
    }
}

fn main() {
    let string1 = String::from("long string");
    let result;
    {
        let string2 = String::from("short");
        result = longest(string1.as_str(), string2.as_str());
        println!("Longest: {}", result);
    }
    // result not valid here (string2 dropped)
}
Data Structures
Arrays and Vectors
fn main() {
    // Array (fixed size)
    let arr: [i32; 5] = [1, 2, 3, 4, 5];
    let arr2 = [3; 5]; // [3, 3, 3, 3, 3]
    let first = arr[0];
    let len = arr.len();

    // Vector (dynamic array)
    let mut vec = Vec::new();
    vec.push(1);
    vec.push(2);
    vec.push(3);

    // Vec macro
    let vec2 = vec![1, 2, 3, 4, 5];

    // Accessing elements
    let third = &vec2[2];
    let third = vec2.get(2); // Returns Option<&T>

    // Iterate
    for i in &vec2 {
        println!("{}", i);
    }

    // Mutable iteration
    let mut vec3 = vec![1, 2, 3];
    for i in &mut vec3 {
        *i += 50;
    }

    // Vector with enum for multiple types
    enum SpreadsheetCell {
        Int(i32),
        Float(f64),
        Text(String),
    }

    let row = vec![
        SpreadsheetCell::Int(3),
        SpreadsheetCell::Float(10.12),
        SpreadsheetCell::Text(String::from("blue")),
    ];
}
HashMap
use std::collections::HashMap;

fn main() {
    // Create HashMap
    let mut scores = HashMap::new();
    scores.insert(String::from("Blue"), 10);
    scores.insert(String::from("Yellow"), 50);

    // Build a map from two vectors (a distinct name avoids shadowing
    // the mutable map above with an incompatible type)
    let teams = vec![String::from("Blue"), String::from("Yellow")];
    let initial_scores = vec![10, 50];
    let team_scores: HashMap<_, _> = teams.iter().zip(initial_scores.iter()).collect();

    // Accessing values
    let team_name = String::from("Blue");
    let score = scores.get(&team_name); // Returns Option<&V>

    // Iterate
    for (key, value) in &scores {
        println!("{}: {}", key, value);
    }

    // Update values
    scores.insert(String::from("Blue"), 25); // Overwrite

    // Only insert if key doesn't exist
    scores.entry(String::from("Blue")).or_insert(50);

    // Update based on old value
    let text = "hello world wonderful world";
    let mut map = HashMap::new();
    for word in text.split_whitespace() {
        let count = map.entry(word).or_insert(0);
        *count += 1;
    }
    println!("{:?}", map);
}
Control Flow
If-Else
fn main() {
    let number = 6;

    if number % 4 == 0 {
        println!("divisible by 4");
    } else if number % 3 == 0 {
        println!("divisible by 3");
    } else {
        println!("not divisible by 4 or 3");
    }

    // If in let statement
    let condition = true;
    let number = if condition { 5 } else { 6 };
}
Loops
fn main() {
    // Loop (infinite)
    let mut count = 0;
    let result = loop {
        count += 1;
        if count == 10 {
            break count * 2; // Return value
        }
    };

    // While loop
    let mut number = 3;
    while number != 0 {
        println!("{}!", number);
        number -= 1;
    }

    // For loop
    let arr = [10, 20, 30, 40, 50];
    for element in arr.iter() {
        println!("{}", element);
    }

    // Range
    for number in 1..4 {
        println!("{}", number); // 1, 2, 3
    }

    // Reverse range
    for number in (1..4).rev() {
        println!("{}", number); // 3, 2, 1
    }

    // Enumerate
    for (i, v) in arr.iter().enumerate() {
        println!("{}: {}", i, v);
    }
}
Match
fn main() {
    // Basic match
    let number = 3;
    match number {
        1 => println!("One"),
        2 => println!("Two"),
        3 => println!("Three"),
        _ => println!("Other"), // Default case
    }

    // Match with return value
    let result = match number {
        1 => "one",
        2 => "two",
        3 => "three",
        _ => "other",
    };

    // Match ranges
    match number {
        1..=5 => println!("1 through 5"),
        _ => println!("something else"),
    }

    // Match Option
    let some_value: Option<i32> = Some(3);
    match some_value {
        Some(i) => println!("Got {}", i),
        None => println!("Got nothing"),
    }

    // if let (concise match)
    if let Some(i) = some_value {
        println!("{}", i);
    }

    // Match guards
    let num = Some(4);
    match num {
        Some(x) if x < 5 => println!("less than five: {}", x),
        Some(x) => println!("{}", x),
        None => (),
    }
}
Structs and Enums
Structs
// Define struct
struct User {
    username: String,
    email: String,
    sign_in_count: u64,
    active: bool,
}

// Tuple struct
struct Color(i32, i32, i32);
struct Point(i32, i32, i32);

// Unit struct (no fields)
struct AlwaysEqual;

impl User {
    // Associated function (constructor)
    fn new(username: String, email: String) -> User {
        User {
            username,
            email,
            sign_in_count: 1,
            active: true,
        }
    }

    // Method
    fn is_active(&self) -> bool {
        self.active
    }

    // Mutable method
    fn deactivate(&mut self) {
        self.active = false;
    }
}

fn main() {
    // Create instance
    let mut user1 = User {
        email: String::from("user@example.com"),
        username: String::from("user123"),
        active: true,
        sign_in_count: 1,
    };

    user1.email = String::from("newemail@example.com");

    // Struct update syntax
    let user2 = User {
        email: String::from("another@example.com"),
        ..user1 // Rest from user1
    };

    // Tuple struct
    let black = Color(0, 0, 0);
    let origin = Point(0, 0, 0);
}
Enums
// Define enum
enum IpAddr {
    V4(u8, u8, u8, u8),
    V6(String),
}

enum Message {
    Quit,
    Move { x: i32, y: i32 },
    Write(String),
    ChangeColor(i32, i32, i32),
}

impl Message {
    fn call(&self) {
        match self {
            Message::Quit => println!("Quit"),
            Message::Move { x, y } => println!("Move to {}, {}", x, y),
            Message::Write(text) => println!("Write: {}", text),
            Message::ChangeColor(r, g, b) => println!("Color: {}, {}, {}", r, g, b),
        }
    }
}

fn main() {
    let home = IpAddr::V4(127, 0, 0, 1);
    let loopback = IpAddr::V6(String::from("::1"));

    let msg = Message::Write(String::from("hello"));
    msg.call();
}
Option and Result
fn main() {
    // Option<T> - value or nothing
    let some_number: Option<i32> = Some(5);
    let no_number: Option<i32> = None;

    // Match on Option
    match some_number {
        Some(i) => println!("{}", i),
        None => println!("nothing"),
    }

    // Unwrap (panics if None)
    let x = Some(5);
    let y = x.unwrap();

    // Unwrap with default
    let z = no_number.unwrap_or(0);

    // Result<T, E> - success or error
    use std::fs::File;
    use std::io::ErrorKind;

    let f = File::open("hello.txt");
    let f = match f {
        Ok(file) => file,
        Err(error) => match error.kind() {
            ErrorKind::NotFound => match File::create("hello.txt") {
                Ok(fc) => fc,
                Err(e) => panic!("Problem creating file: {:?}", e),
            },
            other_error => panic!("Problem opening file: {:?}", other_error),
        },
    };

    // Propagating errors with ?
    fn read_username() -> Result<String, std::io::Error> {
        let mut f = File::open("hello.txt")?;
        let mut s = String::new();
        use std::io::Read;
        f.read_to_string(&mut s)?;
        Ok(s)
    }
}
Traits
use std::fmt::Debug;
use std::fmt::Display;

// Define trait
trait Summary {
    fn summarize(&self) -> String;

    // Default implementation
    fn default_summary(&self) -> String {
        String::from("(Read more...)")
    }
}

// Implement trait
struct NewsArticle {
    headline: String,
    location: String,
    author: String,
}

impl Summary for NewsArticle {
    fn summarize(&self) -> String {
        format!("{}, by {} ({})", self.headline, self.author, self.location)
    }
}

struct Tweet {
    username: String,
    content: String,
}

impl Summary for Tweet {
    fn summarize(&self) -> String {
        format!("{}: {}", self.username, self.content)
    }
}

// Trait as parameter
fn notify(item: &impl Summary) {
    println!("Breaking news! {}", item.summarize());
}

// Trait bound syntax
fn notify2<T: Summary>(item: &T) {
    println!("{}", item.summarize());
}

// Multiple traits
fn notify3<T: Summary + Display>(item: &T) {
    // ...
}

// Where clause
fn some_function<T, U>(t: &T, u: &U) -> i32
where
    T: Display + Clone,
    U: Clone + Debug,
{
    // ...
    0
}

// Return trait
fn returns_summarizable() -> impl Summary {
    Tweet {
        username: String::from("user"),
        content: String::from("content"),
    }
}

fn main() {
    let tweet = Tweet {
        username: String::from("user"),
        content: String::from("Hello, world!"),
    };
    println!("{}", tweet.summarize());
}
Error Handling
use std::fs::File;
use std::io::{self, Read};

// Propagating errors
fn read_username_from_file() -> Result<String, io::Error> {
    let mut f = File::open("username.txt")?;
    let mut s = String::new();
    f.read_to_string(&mut s)?;
    Ok(s)
}

// Custom error types
use std::fmt;

#[derive(Debug)]
enum CustomError {
    IoError(io::Error),
    ParseError,
}

impl fmt::Display for CustomError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            CustomError::IoError(e) => write!(f, "IO error: {}", e),
            CustomError::ParseError => write!(f, "Parse error"),
        }
    }
}

impl From<io::Error> for CustomError {
    fn from(err: io::Error) -> CustomError {
        CustomError::IoError(err)
    }
}

// Panic
fn will_panic() {
    panic!("crash and burn");
}

// Assert
fn check_value(x: i32) {
    assert!(x > 0, "x must be positive");
    assert_eq!(x, 5);
    assert_ne!(x, 0);
}

fn main() {
    // Result with match
    match read_username_from_file() {
        Ok(username) => println!("Username: {}", username),
        Err(e) => println!("Error: {}", e),
    }

    // Unwrap
    let f = File::open("hello.txt").unwrap();

    // Expect (with custom message)
    let f = File::open("hello.txt").expect("Failed to open file");
}
Generics
// Generic function
fn largest<T: PartialOrd>(list: &[T]) -> &T {
    let mut largest = &list[0];
    for item in list {
        if item > largest {
            largest = item;
        }
    }
    largest
}

// Generic struct
struct Point<T> {
    x: T,
    y: T,
}

impl<T> Point<T> {
    fn x(&self) -> &T {
        &self.x
    }
}

// Implement for specific type
impl Point<f32> {
    fn distance_from_origin(&self) -> f32 {
        (self.x.powi(2) + self.y.powi(2)).sqrt()
    }
}

// Multiple generic types
struct Point2<T, U> {
    x: T,
    y: U,
}

// Generic enum
enum Option<T> {
    Some(T),
    None,
}

enum Result<T, E> {
    Ok(T),
    Err(E),
}

fn main() {
    let numbers = vec![34, 50, 25, 100, 65];
    let result = largest(&numbers);

    let integer = Point { x: 5, y: 10 };
    let float = Point { x: 1.0, y: 4.0 };
    let mixed = Point2 { x: 5, y: 4.0 };
}
Concurrency
Threads
use std::thread;
use std::time::Duration;

fn main() {
    // Spawn thread
    let handle = thread::spawn(|| {
        for i in 1..10 {
            println!("spawned thread: {}", i);
            thread::sleep(Duration::from_millis(1));
        }
    });

    // Wait for thread
    handle.join().unwrap();

    // Move closure
    let v = vec![1, 2, 3];
    let handle = thread::spawn(move || {
        println!("vector: {:?}", v);
    });
    handle.join().unwrap();
}
Channels
use std::sync::mpsc;
use std::thread;

fn main() {
    // Create channel
    let (tx, rx) = mpsc::channel();

    thread::spawn(move || {
        let val = String::from("hi");
        tx.send(val).unwrap();
    });

    let received = rx.recv().unwrap();
    println!("Got: {}", received);

    // Multiple producers
    let (tx, rx) = mpsc::channel();
    let tx1 = tx.clone();

    thread::spawn(move || {
        tx.send(String::from("hi from thread 1")).unwrap();
    });

    thread::spawn(move || {
        tx1.send(String::from("hi from thread 2")).unwrap();
    });

    for received in rx {
        println!("Got: {}", received);
    }
}
Shared State
use std::sync::{Arc, Mutex};
use std::thread;

fn main() {
    // Mutex for mutual exclusion
    let counter = Arc::new(Mutex::new(0));
    let mut handles = vec![];

    for _ in 0..10 {
        let counter = Arc::clone(&counter);
        let handle = thread::spawn(move || {
            let mut num = counter.lock().unwrap();
            *num += 1;
        });
        handles.push(handle);
    }

    for handle in handles {
        handle.join().unwrap();
    }

    println!("Result: {}", *counter.lock().unwrap());
}
Common Patterns
Builder Pattern
#[derive(Default)]
struct User {
    username: String,
    email: String,
    age: Option<u32>,
}

impl User {
    fn builder() -> UserBuilder {
        UserBuilder::default()
    }
}

#[derive(Default)]
struct UserBuilder {
    username: String,
    email: String,
    age: Option<u32>,
}

impl UserBuilder {
    fn username(mut self, username: &str) -> Self {
        self.username = username.to_string();
        self
    }

    fn email(mut self, email: &str) -> Self {
        self.email = email.to_string();
        self
    }

    fn age(mut self, age: u32) -> Self {
        self.age = Some(age);
        self
    }

    fn build(self) -> User {
        User {
            username: self.username,
            email: self.email,
            age: self.age,
        }
    }
}

fn main() {
    let user = User::builder()
        .username("alice")
        .email("alice@example.com")
        .age(30)
        .build();
}
Newtype Pattern
// Wrap existing type
struct Wrapper(Vec<String>);

impl Wrapper {
    fn new() -> Self {
        Wrapper(Vec::new())
    }

    fn push(&mut self, s: String) {
        self.0.push(s);
    }
}
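The usual payoff of the newtype pattern is working around the orphan rule: you cannot implement an external trait (like `Display`) directly on an external type (like `Vec<String>`), but you can implement it on your own wrapper. A minimal sketch:

```rust
use std::fmt;

// Newtype wrapper around Vec<String>
struct Wrapper(Vec<String>);

// Display can't be implemented for Vec<String> directly (orphan rule),
// but it can be implemented for the local wrapper type
impl fmt::Display for Wrapper {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "[{}]", self.0.join(", "))
    }
}

fn main() {
    let w = Wrapper(vec![String::from("hello"), String::from("world")]);
    println!("{}", w); // [hello, world]
}
```

The cost is that the wrapper does not inherit `Vec`'s methods; delegate the ones you need (as `push` does above) or implement `Deref`.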
Testing
// Unit tests
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn it_works() {
        assert_eq!(2 + 2, 4);
    }

    #[test]
    fn it_adds() {
        assert_eq!(add(2, 2), 4);
    }

    #[test]
    #[should_panic]
    fn it_panics() {
        panic!("panic!");
    }

    #[test]
    fn it_returns_result() -> Result<(), String> {
        if 2 + 2 == 4 {
            Ok(())
        } else {
            Err(String::from("two plus two does not equal four"))
        }
    }
}

fn add(a: i32, b: i32) -> i32 {
    a + b
}
Cargo Commands
# Create new project
cargo new project_name
cargo new --lib library_name
# Build project
cargo build
cargo build --release
# Run project
cargo run
# Check code
cargo check
# Run tests
cargo test
# Generate documentation
cargo doc --open
# Update dependencies
cargo update
# Format code
cargo fmt
# Lint code
cargo clippy
Best Practices
- Use ownership properly - Avoid unnecessary clones
- Handle errors with Result - Don't unwrap in production
- Use iterators - More efficient and idiomatic
- Prefer &str over String for function parameters
- Use Option instead of null
- Implement Debug trait for custom types
- Use pattern matching instead of if chains
- Follow naming conventions - snake_case for variables/functions
- Write tests - cargo test
- Use clippy - cargo clippy for linting
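To make a couple of the guidelines above concrete — iterator chains instead of index loops, and `&str` rather than `String` parameters — here is a small sketch (the function names are illustrative, not from any particular crate):

```rust
// Iterator chain: filter + map + collect instead of an index loop
fn squares_of_evens(nums: &[i32]) -> Vec<i32> {
    nums.iter()
        .filter(|n| *n % 2 == 0) // keep even numbers
        .map(|n| n * n)          // square each one
        .collect()               // gather into a Vec
}

// Taking &str lets callers pass both a String and a string literal
fn greet(name: &str) -> String {
    format!("Hello, {}!", name)
}

fn main() {
    let nums = [1, 2, 3, 4, 5, 6];
    println!("{:?}", squares_of_evens(&nums)); // [4, 16, 36]

    let owned = String::from("Alice");
    println!("{}", greet(&owned)); // &String coerces to &str
    println!("{}", greet("Bob"));  // literal works too
}
```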
Common Libraries
- serde: Serialization/deserialization
- tokio: Async runtime
- reqwest: HTTP client
- actix-web: Web framework
- diesel: ORM
- clap: CLI argument parsing
- log: Logging facade
- anyhow: Error handling
- thiserror: Custom error types
SQL (Structured Query Language)
Overview
SQL is a domain-specific language used for managing and manipulating relational databases. It's the standard language for relational database management systems (RDBMS).
Key Concepts:
- Declarative language (what, not how)
- ACID properties (Atomicity, Consistency, Isolation, Durability)
- Set-based operations
- Data definition, manipulation, and querying
- Transaction management
Popular Database Systems:
- PostgreSQL
- MySQL/MariaDB
- Oracle Database
- Microsoft SQL Server
- SQLite
Basic Syntax
Data Types
-- Numeric
INT, INTEGER -- Whole numbers
SMALLINT, BIGINT -- Different sizes
DECIMAL(10, 2), NUMERIC -- Fixed-point numbers
FLOAT, REAL, DOUBLE -- Floating-point numbers
-- String
CHAR(10) -- Fixed length
VARCHAR(255) -- Variable length
TEXT -- Long text
-- Date and Time
DATE -- Date only
TIME -- Time only
TIMESTAMP, DATETIME -- Date and time
YEAR -- Year only
-- Boolean
BOOLEAN, BOOL -- True/False
-- Binary
BLOB -- Binary large object
BYTEA (PostgreSQL) -- Binary data
-- Other
JSON, JSONB (PostgreSQL) -- JSON data
UUID -- Universally unique identifier
ENUM('small', 'medium') -- Enumerated type
DDL (Data Definition Language)
CREATE
-- Create database
CREATE DATABASE mydb;
CREATE DATABASE IF NOT EXISTS mydb;
-- Use database
USE mydb;
-- Create table
CREATE TABLE users (
id INT PRIMARY KEY AUTO_INCREMENT,
username VARCHAR(50) NOT NULL UNIQUE,
email VARCHAR(100) NOT NULL UNIQUE,
password VARCHAR(255) NOT NULL,
age INT CHECK (age >= 18),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
);
-- Create table with foreign key
CREATE TABLE posts (
id INT PRIMARY KEY AUTO_INCREMENT,
user_id INT NOT NULL,
title VARCHAR(255) NOT NULL,
content TEXT,
published BOOLEAN DEFAULT FALSE,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE
);
-- Create table with composite primary key
CREATE TABLE user_roles (
user_id INT,
role_id INT,
granted_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (user_id, role_id),
FOREIGN KEY (user_id) REFERENCES users(id),
FOREIGN KEY (role_id) REFERENCES roles(id)
);
-- Create table from query
CREATE TABLE archived_users AS
SELECT * FROM users WHERE created_at < '2020-01-01';
ALTER
-- Add column
ALTER TABLE users ADD COLUMN phone VARCHAR(20);
-- Modify column
ALTER TABLE users MODIFY COLUMN email VARCHAR(150);
ALTER TABLE users ALTER COLUMN age SET DEFAULT 18;
-- Rename column
ALTER TABLE users RENAME COLUMN username TO user_name;
-- Drop column
ALTER TABLE users DROP COLUMN phone;
-- Add constraint
ALTER TABLE users ADD CONSTRAINT chk_age CHECK (age >= 18);
ALTER TABLE users ADD UNIQUE (email);
-- Drop constraint
ALTER TABLE users DROP CONSTRAINT chk_age;
-- Rename table
ALTER TABLE users RENAME TO customers;
DROP
-- Drop table
DROP TABLE users;
DROP TABLE IF EXISTS users;
-- Drop database
DROP DATABASE mydb;
DROP DATABASE IF EXISTS mydb;
-- Truncate (delete all rows, keep structure)
TRUNCATE TABLE users;
DML (Data Manipulation Language)
INSERT
-- Insert single row
INSERT INTO users (username, email, password, age)
VALUES ('alice', 'alice@example.com', 'hashed_pwd', 30);
-- Insert multiple rows
INSERT INTO users (username, email, password, age)
VALUES
('bob', 'bob@example.com', 'hashed_pwd', 25),
('charlie', 'charlie@example.com', 'hashed_pwd', 35),
('diana', 'diana@example.com', 'hashed_pwd', 28);
-- Insert from select
INSERT INTO archived_users
SELECT * FROM users WHERE created_at < '2020-01-01';
-- Insert or update (MySQL)
INSERT INTO users (id, username, email)
VALUES (1, 'alice', 'alice@example.com')
ON DUPLICATE KEY UPDATE email = VALUES(email);
-- Insert or ignore (MySQL)
INSERT IGNORE INTO users (username, email)
VALUES ('alice', 'alice@example.com');
-- Upsert (PostgreSQL)
INSERT INTO users (id, username, email)
VALUES (1, 'alice', 'alice@example.com')
ON CONFLICT (id) DO UPDATE SET email = EXCLUDED.email;
UPDATE
-- Update single row
UPDATE users
SET email = 'newemail@example.com'
WHERE id = 1;
-- Update multiple columns
UPDATE users
SET email = 'alice@newdomain.com',
age = 31,
updated_at = CURRENT_TIMESTAMP
WHERE username = 'alice';
-- Update with condition
UPDATE users
SET age = age + 1
WHERE created_at < '2020-01-01';
-- Update from join
UPDATE users u
INNER JOIN orders o ON u.id = o.user_id
SET u.total_orders = (
SELECT COUNT(*) FROM orders WHERE user_id = u.id
)
WHERE o.created_at > '2024-01-01';
-- Update all rows (dangerous!)
UPDATE users SET active = TRUE;
DELETE
-- Delete specific row
DELETE FROM users WHERE id = 1;
-- Delete with condition
DELETE FROM users WHERE created_at < '2020-01-01';
-- Delete with join (MySQL)
DELETE u FROM users u
INNER JOIN orders o ON u.id = o.user_id
WHERE o.status = 'cancelled';
-- Delete all rows (dangerous!)
DELETE FROM users;
-- Soft delete (recommended)
UPDATE users SET deleted_at = CURRENT_TIMESTAMP WHERE id = 1;
DQL (Data Query Language)
SELECT
-- Select all columns
SELECT * FROM users;
-- Select specific columns
SELECT username, email FROM users;
-- Select with alias
SELECT username AS name, email AS contact FROM users;
-- Select with calculation
SELECT
username,
age,
YEAR(CURRENT_DATE) - YEAR(created_at) AS years_member
FROM users;
-- Select distinct
SELECT DISTINCT age FROM users;
-- Select with limit
SELECT * FROM users LIMIT 10;
SELECT * FROM users LIMIT 10 OFFSET 20; -- Skip first 20
-- Select top (SQL Server)
SELECT TOP 10 * FROM users;
WHERE
-- Basic conditions
SELECT * FROM users WHERE age > 25;
SELECT * FROM users WHERE username = 'alice';
SELECT * FROM users WHERE age >= 18 AND age <= 65;
-- IN operator
SELECT * FROM users WHERE age IN (25, 30, 35);
SELECT * FROM users WHERE username IN ('alice', 'bob', 'charlie');
-- BETWEEN
SELECT * FROM users WHERE age BETWEEN 18 AND 65;
SELECT * FROM users WHERE created_at BETWEEN '2023-01-01' AND '2023-12-31';
-- LIKE (pattern matching)
SELECT * FROM users WHERE email LIKE '%@gmail.com';
SELECT * FROM users WHERE username LIKE 'a%'; -- Starts with 'a'
SELECT * FROM users WHERE username LIKE '%a'; -- Ends with 'a'
SELECT * FROM users WHERE username LIKE '%a%'; -- Contains 'a'
SELECT * FROM users WHERE username LIKE 'a_b'; -- _ matches single char
-- IS NULL / IS NOT NULL
SELECT * FROM users WHERE phone IS NULL;
SELECT * FROM users WHERE phone IS NOT NULL;
-- NOT
SELECT * FROM users WHERE NOT age > 25;
SELECT * FROM users WHERE age NOT IN (25, 30, 35);
-- Combining conditions
SELECT * FROM users
WHERE (age > 25 OR username LIKE 'a%')
AND email IS NOT NULL;
ORDER BY
-- Sort ascending
SELECT * FROM users ORDER BY age;
SELECT * FROM users ORDER BY age ASC;
-- Sort descending
SELECT * FROM users ORDER BY age DESC;
-- Sort by multiple columns
SELECT * FROM users ORDER BY age DESC, username ASC;
-- Sort by calculated column
SELECT username, age * 2 AS double_age
FROM users
ORDER BY double_age DESC;
-- Sort with NULL handling (PostgreSQL/Oracle syntax; not supported in MySQL)
SELECT * FROM users ORDER BY phone NULLS FIRST;
SELECT * FROM users ORDER BY phone NULLS LAST;
GROUP BY
-- Count users by age
SELECT age, COUNT(*) as count
FROM users
GROUP BY age;
-- Multiple aggregations
SELECT
age,
COUNT(*) as count,
AVG(age) as avg_age,
MIN(age) as min_age,
MAX(age) as max_age
FROM users
GROUP BY age;
-- Group by multiple columns
SELECT
age,
YEAR(created_at) as year,
COUNT(*) as count
FROM users
GROUP BY age, YEAR(created_at);
-- HAVING (filter groups)
SELECT age, COUNT(*) as count
FROM users
GROUP BY age
HAVING COUNT(*) > 5;
-- GROUP BY with ORDER BY
SELECT age, COUNT(*) as count
FROM users
GROUP BY age
HAVING COUNT(*) > 5
ORDER BY count DESC;
Joins
-- INNER JOIN (only matching rows)
SELECT u.username, p.title
FROM users u
INNER JOIN posts p ON u.id = p.user_id;
-- LEFT JOIN (all from left, matching from right)
SELECT u.username, p.title
FROM users u
LEFT JOIN posts p ON u.id = p.user_id;
-- RIGHT JOIN (all from right, matching from left)
SELECT u.username, p.title
FROM users u
RIGHT JOIN posts p ON u.id = p.user_id;
-- FULL OUTER JOIN (all from both)
SELECT u.username, p.title
FROM users u
FULL OUTER JOIN posts p ON u.id = p.user_id;
-- CROSS JOIN (Cartesian product)
SELECT u.username, r.role_name
FROM users u
CROSS JOIN roles r;
-- Self join
SELECT
e1.name as employee,
e2.name as manager
FROM employees e1
LEFT JOIN employees e2 ON e1.manager_id = e2.id;
-- Multiple joins
SELECT
u.username,
p.title,
c.content as comment
FROM users u
INNER JOIN posts p ON u.id = p.user_id
INNER JOIN comments c ON p.id = c.post_id;
-- Join with conditions
SELECT u.username, p.title
FROM users u
LEFT JOIN posts p ON u.id = p.user_id AND p.published = TRUE;
Subqueries
-- Subquery in WHERE
SELECT username FROM users
WHERE id IN (
SELECT user_id FROM orders WHERE total > 100
);
-- Subquery in SELECT
SELECT
username,
(SELECT COUNT(*) FROM posts WHERE user_id = users.id) as post_count
FROM users;
-- Subquery in FROM
SELECT avg_age FROM (
SELECT AVG(age) as avg_age FROM users GROUP BY city
) as subquery;
-- Correlated subquery
SELECT username FROM users u
WHERE age > (
SELECT AVG(age) FROM users WHERE city = u.city
);
-- EXISTS
SELECT username FROM users u
WHERE EXISTS (
SELECT 1 FROM orders WHERE user_id = u.id
);
-- NOT EXISTS
SELECT username FROM users u
WHERE NOT EXISTS (
SELECT 1 FROM orders WHERE user_id = u.id
);
-- ANY / ALL
SELECT username FROM users
WHERE age > ANY (SELECT age FROM users WHERE city = 'NYC');
SELECT username FROM users
WHERE age > ALL (SELECT age FROM users WHERE city = 'NYC');
Aggregate Functions
-- COUNT
SELECT COUNT(*) FROM users;
SELECT COUNT(DISTINCT age) FROM users;
-- SUM
SELECT SUM(total) FROM orders;
-- AVG
SELECT AVG(age) FROM users;
-- MIN / MAX
SELECT MIN(age), MAX(age) FROM users;
-- String aggregation (PostgreSQL)
SELECT STRING_AGG(username, ', ') FROM users;
-- GROUP_CONCAT (MySQL)
SELECT GROUP_CONCAT(username SEPARATOR ', ') FROM users;
-- Combined
SELECT
COUNT(*) as total_users,
AVG(age) as average_age,
MIN(age) as youngest,
MAX(age) as oldest,
SUM(CASE WHEN age >= 18 THEN 1 ELSE 0 END) as adults
FROM users;
Common Table Expressions (CTE)
-- Basic CTE
WITH active_users AS (
SELECT * FROM users WHERE active = TRUE
)
SELECT * FROM active_users WHERE age > 25;
-- Multiple CTEs
WITH
active_users AS (
SELECT * FROM users WHERE active = TRUE
),
user_posts AS (
SELECT user_id, COUNT(*) as post_count
FROM posts
GROUP BY user_id
)
SELECT
au.username,
up.post_count
FROM active_users au
LEFT JOIN user_posts up ON au.id = up.user_id;
-- Recursive CTE (hierarchy)
WITH RECURSIVE employee_hierarchy AS (
-- Base case
SELECT id, name, manager_id, 1 as level
FROM employees
WHERE manager_id IS NULL
UNION ALL
-- Recursive case
SELECT e.id, e.name, e.manager_id, eh.level + 1
FROM employees e
INNER JOIN employee_hierarchy eh ON e.manager_id = eh.id
)
SELECT * FROM employee_hierarchy ORDER BY level;
Window Functions
-- ROW_NUMBER
SELECT
username,
age,
ROW_NUMBER() OVER (ORDER BY age DESC) as row_num
FROM users;
-- RANK / DENSE_RANK
SELECT
username,
score,
RANK() OVER (ORDER BY score DESC) as rank,
DENSE_RANK() OVER (ORDER BY score DESC) as dense_rank
FROM users;
-- Partition by
SELECT
city,
username,
age,
AVG(age) OVER (PARTITION BY city) as avg_city_age
FROM users;
-- Running total
SELECT
date,
amount,
SUM(amount) OVER (ORDER BY date) as running_total
FROM sales;
-- LAG / LEAD (previous/next row)
SELECT
date,
revenue,
LAG(revenue) OVER (ORDER BY date) as prev_revenue,
LEAD(revenue) OVER (ORDER BY date) as next_revenue
FROM daily_sales;
-- NTILE (divide into buckets)
SELECT
username,
score,
NTILE(4) OVER (ORDER BY score DESC) as quartile
FROM users;
Indexes
-- Create index
CREATE INDEX idx_users_email ON users(email);
-- Create unique index
CREATE UNIQUE INDEX idx_users_username ON users(username);
-- Create composite index
CREATE INDEX idx_users_age_city ON users(age, city);
-- Create partial index (PostgreSQL)
CREATE INDEX idx_active_users ON users(username)
WHERE active = TRUE;
-- Create index with condition (filtered index - SQL Server)
CREATE INDEX idx_active_users ON users(username)
WHERE active = 1;
-- Full-text index (MySQL)
CREATE FULLTEXT INDEX idx_posts_content ON posts(title, content);
-- Drop index
DROP INDEX idx_users_email ON users;
-- Show indexes
SHOW INDEX FROM users; -- MySQL
SELECT * FROM pg_indexes WHERE tablename = 'users'; -- PostgreSQL
Transactions
-- Start transaction
BEGIN;
START TRANSACTION;
-- Commit transaction
COMMIT;
-- Rollback transaction
ROLLBACK;
-- Example transaction
BEGIN;
UPDATE accounts SET balance = balance - 100 WHERE id = 1;
UPDATE accounts SET balance = balance + 100 WHERE id = 2;
-- Check if everything is okay
-- (IF ... THEN is procedural syntax: run it inside a stored
-- procedure/function, or do this check in application code)
IF (SELECT balance FROM accounts WHERE id = 1) >= 0 THEN
    COMMIT;
ELSE
    ROLLBACK;
END IF;
-- Savepoint
BEGIN;
UPDATE users SET age = 30 WHERE id = 1;
SAVEPOINT my_savepoint;
UPDATE users SET age = 40 WHERE id = 2;
ROLLBACK TO SAVEPOINT my_savepoint; -- Only rollback second update
COMMIT;
-- Transaction isolation levels
SET TRANSACTION ISOLATION LEVEL READ UNCOMMITTED;
SET TRANSACTION ISOLATION LEVEL READ COMMITTED;
SET TRANSACTION ISOLATION LEVEL REPEATABLE READ;
SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;
Views
-- Create view
CREATE VIEW active_users AS
SELECT id, username, email
FROM users
WHERE active = TRUE;
-- Use view
SELECT * FROM active_users;
-- Create or replace view
CREATE OR REPLACE VIEW user_stats AS
SELECT
u.id,
u.username,
COUNT(p.id) as post_count,
COUNT(DISTINCT c.id) as comment_count
FROM users u
LEFT JOIN posts p ON u.id = p.user_id
LEFT JOIN comments c ON u.id = c.user_id
GROUP BY u.id, u.username;
-- Materialized view (PostgreSQL)
CREATE MATERIALIZED VIEW user_stats_mv AS
SELECT
u.id,
COUNT(p.id) as post_count
FROM users u
LEFT JOIN posts p ON u.id = p.user_id
GROUP BY u.id;
-- Refresh materialized view
REFRESH MATERIALIZED VIEW user_stats_mv;
-- Drop view
DROP VIEW active_users;
DROP MATERIALIZED VIEW user_stats_mv;
Stored Procedures and Functions
-- MySQL stored procedure
DELIMITER //
CREATE PROCEDURE GetUsersByAge(IN min_age INT)
BEGIN
SELECT * FROM users WHERE age >= min_age;
END //
DELIMITER ;
-- Call procedure
CALL GetUsersByAge(25);
-- Function (MySQL)
DELIMITER //
CREATE FUNCTION CalculateAge(birth_date DATE)
RETURNS INT
DETERMINISTIC
BEGIN
RETURN YEAR(CURRENT_DATE) - YEAR(birth_date);
END //
DELIMITER ;
-- Use function
SELECT username, CalculateAge(birth_date) as age FROM users;
-- PostgreSQL function
CREATE OR REPLACE FUNCTION get_user_count()
RETURNS INTEGER AS $$
BEGIN
RETURN (SELECT COUNT(*) FROM users);
END;
$$ LANGUAGE plpgsql;
-- Call function
SELECT get_user_count();
-- Drop procedure/function
DROP PROCEDURE GetUsersByAge;
DROP FUNCTION CalculateAge;
Common Patterns
Pagination
-- Offset pagination
SELECT * FROM users
ORDER BY id
LIMIT 10 OFFSET 20; -- Page 3 with 10 rows per page (skips first 20 rows)
-- Cursor-based pagination (more efficient)
SELECT * FROM users
WHERE id > 100 -- Last seen ID
ORDER BY id
LIMIT 10;
Finding Duplicates
-- Find duplicate emails
SELECT email, COUNT(*) as count
FROM users
GROUP BY email
HAVING COUNT(*) > 1;
-- Get duplicate rows with details
SELECT u.*
FROM users u
INNER JOIN (
SELECT email FROM users
GROUP BY email
HAVING COUNT(*) > 1
) dup ON u.email = dup.email;
Ranking
-- Top N per group
WITH ranked AS (
SELECT
*,
ROW_NUMBER() OVER (PARTITION BY category ORDER BY score DESC) as rn
FROM products
)
SELECT * FROM ranked WHERE rn <= 3;
Running Totals
-- Running total
SELECT
date,
revenue,
SUM(revenue) OVER (ORDER BY date) as cumulative_revenue
FROM sales
ORDER BY date;
Pivot Table
-- MySQL
SELECT
username,
SUM(CASE WHEN YEAR(created_at) = 2023 THEN 1 ELSE 0 END) as year_2023,
SUM(CASE WHEN YEAR(created_at) = 2024 THEN 1 ELSE 0 END) as year_2024
FROM users
GROUP BY username;
-- PostgreSQL (crosstab, requires the tablefunc extension; EXTRACT replaces MySQL's YEAR())
CREATE EXTENSION IF NOT EXISTS tablefunc;
SELECT * FROM crosstab(
'SELECT username, EXTRACT(YEAR FROM created_at)::INT, COUNT(*) FROM users GROUP BY 1, 2 ORDER BY 1, 2',
'SELECT DISTINCT EXTRACT(YEAR FROM created_at)::INT FROM users ORDER BY 1'
) AS ct(username TEXT, year_2023 BIGINT, year_2024 BIGINT);
Performance Optimization
Best Practices
- Use indexes wisely
-- Index columns used in WHERE, JOIN, ORDER BY
CREATE INDEX idx_users_email ON users(email);
- Avoid SELECT *
-- Bad
SELECT * FROM users;
-- Good
SELECT id, username, email FROM users;
- Use LIMIT
SELECT * FROM users LIMIT 100;
- Use JOIN instead of subqueries when possible
-- Slower
SELECT * FROM users WHERE id IN (SELECT user_id FROM orders);
-- Faster
SELECT DISTINCT u.* FROM users u INNER JOIN orders o ON u.id = o.user_id;
- Use EXPLAIN to analyze queries
EXPLAIN SELECT * FROM users WHERE email = 'alice@example.com';
EXPLAIN ANALYZE SELECT * FROM users WHERE age > 25;
- Avoid functions on indexed columns
-- Bad (can't use index)
SELECT * FROM users WHERE YEAR(created_at) = 2024;
-- Good
SELECT * FROM users WHERE created_at >= '2024-01-01' AND created_at < '2025-01-01';
- Use covering indexes
CREATE INDEX idx_users_email_username ON users(email, username);
-- This query can be satisfied entirely from the index
SELECT username FROM users WHERE email = 'alice@example.com';
Common Functions
String Functions
-- CONCAT
SELECT CONCAT(first_name, ' ', last_name) as full_name FROM users;
-- UPPER / LOWER
SELECT UPPER(username), LOWER(email) FROM users;
-- LENGTH / CHAR_LENGTH
SELECT LENGTH(username), CHAR_LENGTH(username) FROM users;
-- SUBSTRING
SELECT SUBSTRING(email, 1, 10) FROM users;
-- TRIM
SELECT TRIM(username) FROM users;
-- REPLACE
SELECT REPLACE(email, '@gmail.com', '@newdomain.com') FROM users;
Date Functions
-- Current date/time
SELECT NOW(), CURRENT_DATE, CURRENT_TIME;
-- Date arithmetic
SELECT DATE_ADD(created_at, INTERVAL 30 DAY) FROM users;
SELECT DATE_SUB(created_at, INTERVAL 1 YEAR) FROM users;
-- Date difference
SELECT DATEDIFF(NOW(), created_at) as days_since_creation FROM users;
-- Date formatting
SELECT DATE_FORMAT(created_at, '%Y-%m-%d %H:%i:%s') FROM users;
-- Extract parts
SELECT
YEAR(created_at) as year,
MONTH(created_at) as month,
DAY(created_at) as day
FROM users;
Conditional Functions
-- CASE
SELECT
username,
CASE
WHEN age < 18 THEN 'Minor'
WHEN age >= 18 AND age < 65 THEN 'Adult'
ELSE 'Senior'
END as age_group
FROM users;
-- IF (MySQL)
SELECT IF(age >= 18, 'Adult', 'Minor') as status FROM users;
-- COALESCE (first non-null value)
SELECT COALESCE(phone, email, 'No contact') FROM users;
-- NULLIF (return NULL if equal)
SELECT NULLIF(age, 0) FROM users;
Security Best Practices
- Use parameterized queries (prevent SQL injection)
-- Bad (vulnerable to SQL injection)
SELECT * FROM users WHERE username = '$user_input';
-- Good (parameterized)
SELECT * FROM users WHERE username = ?;
- Principle of least privilege
- Grant minimum necessary permissions
- Use separate accounts for different applications
- Encrypt sensitive data
-- Store password hashes, never plain text
-- Prefer a dedicated password hash (bcrypt, Argon2) in application code;
-- a plain SHA-256 digest is too fast to resist brute-force attacks
INSERT INTO users (username, password_hash) VALUES ('alice', SHA2('password', 256));
- Regular backups
mysqldump -u root -p database_name > backup.sql
pg_dump database_name > backup.sql
- Input validation
- Validate and sanitize all user inputs
- Use constraints in database schema
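The first point above can be demonstrated end to end with Python's built-in sqlite3 module; this is an illustrative sketch (the in-memory schema is a minimal stand-in for the users table used throughout this section):

```python
import sqlite3

# In-memory demo database (hypothetical minimal schema)
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, username TEXT)")
conn.execute("INSERT INTO users (username) VALUES ('alice'), ('bob')")

# Malicious input that would break a string-concatenated query
user_input = "alice' OR '1'='1"

# Parameterized: the driver treats the value as data, never as SQL
rows = conn.execute(
    "SELECT id, username FROM users WHERE username = ?", (user_input,)
).fetchall()
print(rows)  # [] -- the injection payload matches no username

rows = conn.execute(
    "SELECT id, username FROM users WHERE username = ?", ("alice",)
).fetchall()
print(rows)  # [(1, 'alice')]
```

Had the first query been built by string concatenation, the payload would have turned the WHERE clause into a tautology and returned every row.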
Database-Specific Features
PostgreSQL
-- Array type
CREATE TABLE users (tags TEXT[]);
INSERT INTO users (tags) VALUES (ARRAY['admin', 'moderator']);
SELECT * FROM users WHERE 'admin' = ANY(tags);
-- JSON type
CREATE TABLE users (metadata JSONB);
INSERT INTO users (metadata) VALUES ('{"age": 30, "city": "NYC"}');
SELECT metadata->>'age' FROM users;
-- Generate series
SELECT * FROM generate_series(1, 10);
MySQL
-- Auto increment
CREATE TABLE users (
id INT AUTO_INCREMENT PRIMARY KEY
);
-- Full-text search
CREATE FULLTEXT INDEX ft_content ON posts(content);
SELECT * FROM posts WHERE MATCH(content) AGAINST('search term');
-- JSON functions
SELECT JSON_EXTRACT(metadata, '$.age') FROM users;
Common Database Tools
- MySQL Workbench: GUI for MySQL
- pgAdmin: GUI for PostgreSQL
- DBeaver: Universal database tool
- TablePlus: Modern database client
- DataGrip: JetBrains database IDE
Command Line Tools
# MySQL
mysql -u root -p
mysql -u root -p database_name < backup.sql
mysqldump -u root -p database_name > backup.sql
# PostgreSQL
psql -U postgres
psql -U postgres database_name < backup.sql
pg_dump database_name > backup.sql
# SQLite
sqlite3 database.db
.tables
.schema table_name
.quit
Interview Questions
LeetCode Patterns
LeetCode is a popular platform for practicing coding problems and preparing for technical interviews. Many problems on LeetCode can be categorized into specific patterns. Understanding these patterns can help you approach and solve problems more efficiently. Here are some common LeetCode patterns:
1. Sliding Window
The sliding window pattern is used to solve problems that involve a contiguous sequence of elements, such as subarrays or substrings. This pattern helps in reducing the time complexity by avoiding redundant calculations.
Example Problem: Find the maximum sum of a subarray of size k.
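A minimal Python sketch of this problem (the helper name `max_sum_subarray` is illustrative):

```python
def max_sum_subarray(nums, k):
    """Maximum sum of any contiguous subarray of size k, in O(n)."""
    if len(nums) < k:
        return None
    window = sum(nums[:k])               # sum of the first window
    best = window
    for i in range(k, len(nums)):
        window += nums[i] - nums[i - k]  # slide: add the new element, drop the old
        best = max(best, window)
    return best

print(max_sum_subarray([2, 1, 5, 1, 3, 2], 3))  # 9 (5 + 1 + 3)
```

Each slide reuses the previous window's sum, which is what brings the brute-force O(n*k) down to O(n).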
2. Two Pointers
The two pointers pattern is used to solve problems involving sorted arrays or linked lists. It involves using two pointers to iterate through the data structure, often from opposite ends, to find pairs or subarrays that meet certain criteria.
Example Problem: Find two numbers in a sorted array that add up to a given target.
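A sketch of the two-pointer sweep for this problem (returns the pair of indices, or None):

```python
def pair_with_target(nums, target):
    """Find indices of two values in a sorted list that sum to target."""
    lo, hi = 0, len(nums) - 1
    while lo < hi:
        s = nums[lo] + nums[hi]
        if s == target:
            return lo, hi
        if s < target:
            lo += 1   # need a bigger sum: advance the left pointer
        else:
            hi -= 1   # need a smaller sum: retreat the right pointer
    return None

print(pair_with_target([1, 2, 3, 4, 6], 6))  # (1, 3) -> 2 + 4
```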
3. Fast and Slow Pointers
The fast and slow pointers pattern is used to detect cycles in linked lists or arrays. The fast pointer moves twice as fast as the slow pointer, and if there is a cycle, they will eventually meet.
Example Problem: Detect a cycle in a linked list.
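Floyd's tortoise-and-hare, sketched in Python with a minimal `ListNode` (both names illustrative):

```python
class ListNode:
    def __init__(self, val):
        self.val = val
        self.next = None

def has_cycle(head):
    """Floyd's cycle detection: O(n) time, O(1) extra space."""
    slow = fast = head
    while fast and fast.next:
        slow = slow.next        # one step
        fast = fast.next.next   # two steps
        if slow is fast:        # pointers meet only inside a cycle
            return True
    return False

# Build 1 -> 2 -> 3 -> back to 2
a, b, c = ListNode(1), ListNode(2), ListNode(3)
a.next, b.next, c.next = b, c, b
print(has_cycle(a))  # True
```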
4. Merge Intervals
The merge intervals pattern is used to solve problems that involve overlapping intervals. This pattern helps in merging overlapping intervals and simplifying the problem.
Example Problem: Merge overlapping intervals in a list of intervals.
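The merge reduces to a sort followed by a single sweep; a sketch:

```python
def merge_intervals(intervals):
    """Merge overlapping [start, end] intervals."""
    merged = []
    for start, end in sorted(intervals):
        if merged and start <= merged[-1][1]:
            # overlaps the previous interval: extend it
            merged[-1][1] = max(merged[-1][1], end)
        else:
            merged.append([start, end])
    return merged

print(merge_intervals([[1, 3], [2, 6], [8, 10], [15, 18]]))
# [[1, 6], [8, 10], [15, 18]]
```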
5. Cyclic Sort
The cyclic sort pattern is used to solve problems involving arrays where the elements are in a range from 1 to n. This pattern helps in placing each element at its correct index.
Example Problem: Find the missing number in an array of size n containing numbers from 1 to n.
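A sketch of the cyclic-sort approach, here for the common variant where the array holds n distinct numbers from the range 0..n with one missing:

```python
def find_missing(nums):
    """Cyclic sort: put each value v at index v, then scan for the gap."""
    i, n = 0, len(nums)
    while i < n:
        v = nums[i]
        if v < n and nums[v] != v:
            nums[i], nums[v] = nums[v], nums[i]  # swap v into its slot
        else:
            i += 1
    for i, v in enumerate(nums):
        if v != i:
            return i
    return n

print(find_missing([4, 0, 3, 1]))  # 2
```

Each swap places at least one value permanently, so the whole pass is O(n) despite the nested-looking loop.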
6. In-place Reversal of a Linked List
The in-place reversal of a linked list pattern is used to solve problems that require reversing a portion of a linked list. This pattern helps in reversing the nodes of the linked list in-place without using extra space.
Example Problem: Reverse a sublist of a linked list from position m to n.
7. Tree BFS (Breadth-First Search)
The tree BFS pattern is used to solve problems involving tree traversal. This pattern helps in traversing the tree level by level and is useful for problems that require processing nodes in a specific order.
Example Problem: Find the level order traversal of a binary tree.
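A queue-based sketch of level order traversal (minimal `Node` class for illustration):

```python
from collections import deque

class Node:
    def __init__(self, val, left=None, right=None):
        self.val, self.left, self.right = val, left, right

def level_order(root):
    """Return node values grouped level by level."""
    if not root:
        return []
    levels, queue = [], deque([root])
    while queue:
        level = []
        for _ in range(len(queue)):   # drain exactly one level
            node = queue.popleft()
            level.append(node.val)
            if node.left:
                queue.append(node.left)
            if node.right:
                queue.append(node.right)
        levels.append(level)
    return levels

tree = Node(1, Node(2, Node(4)), Node(3))
print(level_order(tree))  # [[1], [2, 3], [4]]
```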
8. Tree DFS (Depth-First Search)
The tree DFS pattern is used to solve problems involving tree traversal. This pattern helps in traversing the tree depth-wise and is useful for problems that require exploring all paths from the root to the leaves.
Example Problem: Find all root-to-leaf paths in a binary tree.
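A recursive DFS sketch with explicit backtracking (minimal `Node` class for illustration):

```python
class Node:
    def __init__(self, val, left=None, right=None):
        self.val, self.left, self.right = val, left, right

def root_to_leaf_paths(root):
    """Collect every root-to-leaf path via DFS."""
    paths = []

    def dfs(node, path):
        if not node:
            return
        path.append(node.val)
        if not node.left and not node.right:   # reached a leaf
            paths.append(list(path))           # snapshot the current path
        else:
            dfs(node.left, path)
            dfs(node.right, path)
        path.pop()                             # backtrack

    dfs(root, [])
    return paths

tree = Node(1, Node(2, Node(4)), Node(3))
print(root_to_leaf_paths(tree))  # [[1, 2, 4], [1, 3]]
```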
9. Two Heaps
The two heaps pattern is used to solve problems that require finding the median of a data stream. This pattern involves using two heaps (a max-heap and a min-heap) to maintain the median efficiently.
Example Problem: Find the median of a data stream.
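A sketch of the two-heap median: a max-heap for the lower half (simulated by negating values, since heapq only provides min-heaps) and a min-heap for the upper half:

```python
import heapq

class MedianFinder:
    def __init__(self):
        self.lo = []  # max-heap of the lower half (values stored negated)
        self.hi = []  # min-heap of the upper half

    def add(self, num):
        heapq.heappush(self.lo, -num)
        # keep every element of lo <= every element of hi
        heapq.heappush(self.hi, -heapq.heappop(self.lo))
        # keep lo at least as large as hi
        if len(self.hi) > len(self.lo):
            heapq.heappush(self.lo, -heapq.heappop(self.hi))

    def median(self):
        if len(self.lo) > len(self.hi):
            return -self.lo[0]
        return (-self.lo[0] + self.hi[0]) / 2

mf = MedianFinder()
for n in [3, 1, 4]:
    mf.add(n)
print(mf.median())  # 3
```

Each insertion is O(log n) and the median is read off the heap tops in O(1).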
10. Subsets
The subsets pattern is used to solve problems that involve generating all possible subsets of a given set. This pattern helps in exploring all combinations and permutations of the elements.
Example Problem: Generate all subsets of a given set of numbers.
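An iterative sketch: each new number doubles the set of subsets by extending every existing one:

```python
def subsets(nums):
    """Generate all 2^n subsets of nums."""
    result = [[]]
    for n in nums:
        # every existing subset, with and without n
        result += [s + [n] for s in result]
    return result

print(subsets([1, 2]))  # [[], [1], [2], [1, 2]]
```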
11. Modified Binary Search
The modified binary search pattern is used to solve problems that involve searching in a sorted array or matrix. This pattern helps in reducing the search space by half in each step.
Example Problem: Find the peak element in a mountain array.
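A binary-search sketch for the mountain-array peak: compare the midpoint with its right neighbor to decide which half holds the peak:

```python
def peak_index(arr):
    """Index of the peak in a mountain array, O(log n)."""
    lo, hi = 0, len(arr) - 1
    while lo < hi:
        mid = (lo + hi) // 2
        if arr[mid] < arr[mid + 1]:
            lo = mid + 1   # still climbing: peak is to the right
        else:
            hi = mid       # descending (or at peak): peak is here or left
    return lo

print(peak_index([1, 3, 5, 4, 2]))  # 2
```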
12. Topological Sort
The topological sort pattern is used to solve problems involving directed acyclic graphs (DAGs). This pattern helps in ordering the nodes of the graph in a linear sequence based on their dependencies.
Example Problem: Find the order of courses to take given their prerequisites.
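A sketch using Kahn's algorithm: repeatedly take a course with no remaining prerequisites (an empty result signals a cycle, i.e. no valid order):

```python
from collections import deque, defaultdict

def course_order(num_courses, prereqs):
    """prereqs: list of (course, prerequisite) pairs. Returns [] on a cycle."""
    graph = defaultdict(list)
    indegree = [0] * num_courses
    for course, pre in prereqs:
        graph[pre].append(course)
        indegree[course] += 1
    # start with all courses that have no prerequisites
    queue = deque(c for c in range(num_courses) if indegree[c] == 0)
    order = []
    while queue:
        c = queue.popleft()
        order.append(c)
        for nxt in graph[c]:
            indegree[nxt] -= 1
            if indegree[nxt] == 0:
                queue.append(nxt)
    return order if len(order) == num_courses else []

print(course_order(4, [(1, 0), (2, 0), (3, 1), (3, 2)]))  # [0, 1, 2, 3]
```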
13. K-way Merge
The k-way merge pattern is used to solve problems that involve merging multiple sorted arrays or lists. This pattern helps in efficiently merging the arrays using a min-heap.
Example Problem: Merge k sorted linked lists.
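A sketch of the k-way merge over plain Python lists; the heap holds one entry per list, so the total cost is O(N log k):

```python
import heapq

def merge_k_sorted(lists):
    """Merge k sorted lists using a min-heap of (value, list_idx, elem_idx)."""
    heap = [(lst[0], i, 0) for i, lst in enumerate(lists) if lst]
    heapq.heapify(heap)
    merged = []
    while heap:
        val, i, j = heapq.heappop(heap)
        merged.append(val)
        if j + 1 < len(lists[i]):
            # push the next element from the same source list
            heapq.heappush(heap, (lists[i][j + 1], i, j + 1))
    return merged

print(merge_k_sorted([[1, 4, 5], [1, 3, 4], [2, 6]]))
# [1, 1, 2, 3, 4, 4, 5, 6]
```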
14. Knapsack (Dynamic Programming)
The knapsack pattern is used to solve problems that involve selecting items with given weights and values to maximize the total value without exceeding a weight limit. This pattern helps in solving optimization problems using dynamic programming.
Example Problem: Find the maximum value that can be obtained by selecting items with given weights and values.
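A sketch of 0/1 knapsack with a one-dimensional DP table, where dp[w] is the best value achievable with total weight at most w:

```python
def knapsack(weights, values, capacity):
    """0/1 knapsack: maximize value without exceeding capacity."""
    dp = [0] * (capacity + 1)
    for wt, val in zip(weights, values):
        # iterate weights downward so each item is used at most once
        for w in range(capacity, wt - 1, -1):
            dp[w] = max(dp[w], dp[w - wt] + val)
    return dp[capacity]

print(knapsack([1, 3, 4, 5], [1, 4, 5, 7], 7))  # 9 (items of weight 3 and 4)
```

The downward iteration is the crux: going upward would let dp[w - wt] already include the current item, turning this into the unbounded knapsack.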
15. Stacks
The stack pattern is used to solve problems that involve processing elements in a last-in, first-out (LIFO) order. This pattern helps in managing function calls, parsing expressions, and backtracking problems.
Example Problem: Evaluate a postfix expression.
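A stack sketch for postfix evaluation: operands are pushed, and each operator pops its two arguments (note the pop order: the second operand comes off first):

```python
def eval_postfix(tokens):
    """Evaluate a postfix expression given as a list of tokens."""
    stack = []
    ops = {
        '+': lambda a, b: a + b,
        '-': lambda a, b: a - b,
        '*': lambda a, b: a * b,
        '/': lambda a, b: int(a / b),  # truncate toward zero
    }
    for tok in tokens:
        if tok in ops:
            b, a = stack.pop(), stack.pop()  # b was pushed last
            stack.append(ops[tok](a, b))
        else:
            stack.append(int(tok))
    return stack[0]

print(eval_postfix(["2", "1", "+", "3", "*"]))  # 9
```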
Understanding these patterns and practicing problems based on them can significantly improve your problem-solving skills and help you perform better in technical interviews.
Design Patterns
Introduction to Design Patterns
Design patterns are proven solutions to common problems encountered in software development. They provide a standardized approach to solving issues related to object creation, structure, and behavior, promoting code reusability, scalability, and maintainability. Understanding design patterns is essential for building robust and efficient software systems.
Categories of Design Patterns
Design patterns are typically categorized into three main types:
- Creational Patterns: Focus on object creation mechanisms, aiming to create objects in a manner suitable to the situation.
- Structural Patterns: Deal with object composition, identifying simple ways to realize relationships between objects.
- Behavioral Patterns: Concerned with communication between objects, highlighting patterns of interaction.
List of Design Patterns and Their Uses
Creational Patterns
- Singleton: Ensures a class has only one instance and provides a global point of access to it. Useful for managing shared resources like logging or configuration settings.
- Factory Method: Defines an interface for creating an object but lets subclasses alter the type of objects that will be created. Useful for creating objects without specifying the exact class of the object to be created.
- Abstract Factory: Provides an interface for creating families of related or dependent objects without specifying their concrete classes. Useful when the system needs to be independent of how its objects are created.
- Builder: Separates the construction of a complex object from its representation, allowing the same construction process to create various representations. Useful for constructing complex objects step by step.
- Prototype: Specifies the kinds of objects to create using a prototypical instance and creates new objects by copying this prototype. Useful when object creation is expensive or complex.
Structural Patterns
- Adapter: Allows incompatible interfaces to work together by converting the interface of one class into another expected by the clients. Useful when integrating legacy systems or third-party libraries.
- Bridge: Decouples an abstraction from its implementation so that the two can vary independently. Useful for handling multiple implementations of an abstraction.
- Composite: Composes objects into tree structures to represent part-whole hierarchies, allowing clients to treat individual objects and compositions uniformly. Useful for representing hierarchical structures like file systems.
- Decorator: Adds additional responsibilities to an object dynamically without altering its structure. Useful for enhancing functionalities of objects without subclassing.
- Facade: Provides a simplified interface to a complex subsystem, making the subsystem easier to use. Useful for reducing dependencies and simplifying client interaction.
- Flyweight: Reduces the cost of creating and maintaining a large number of similar objects by sharing as much data as possible. Useful for handling large numbers of objects efficiently.
- Proxy: Provides a surrogate or placeholder for another object to control access to it. Useful for lazy initialization, access control, or logging.
Behavioral Patterns
- Chain of Responsibility: Passes a request along a chain of handlers, allowing each handler to process or pass it along. Useful for decoupling senders and receivers of requests.
- Command: Encapsulates a request as an object, thereby allowing for parameterization and queuing of requests. Useful for implementing undoable operations or task scheduling.
- Interpreter: Defines a representation for a language's grammar and interprets sentences in the language. Useful for parsing and interpreting expressions or languages.
- Iterator: Provides a way to access elements of an aggregate object sequentially without exposing its underlying representation. Useful for traversing collections.
- Mediator: Defines an object that encapsulates how a set of objects interact, promoting loose coupling by keeping objects from referring to each other explicitly. Useful for reducing complexity in object interactions.
- Memento: Captures and externalizes an object's internal state without violating encapsulation, allowing the object to be restored to this state later. Useful for implementing undo functionality.
- Observer: Defines a one-to-many dependency between objects so that when one object changes state, all its dependents are notified and updated automatically. Useful for event handling and implementing distributed event systems.
- State: Allows an object to alter its behavior when its internal state changes, appearing as if it has changed its class. Useful for managing state-dependent behavior.
- Strategy: Defines a family of algorithms, encapsulates each one, and makes them interchangeable, allowing the algorithm to vary independently from clients that use it. Useful for selecting algorithms dynamically at runtime.
- Template Method: Defines the skeleton of an algorithm in a method, deferring some steps to subclasses, allowing subclasses to redefine certain steps without changing the algorithm's structure. Useful for implementing invariant parts of an algorithm and varying certain steps.
- Visitor: Represents an operation to be performed on elements of an object structure, allowing new operations to be added without modifying the classes of the elements on which it operates. Useful for separating algorithms from object structures.
Creational Patterns
Singleton Pattern
Intent: Ensure a class has only one instance and provide a global point of access to it.
Problem: Sometimes you need exactly one instance of a class (e.g., database connection pool, thread pool, cache, configuration manager). Creating multiple instances wastes resources and can cause inconsistent state.
Solution: Make the class responsible for keeping track of its sole instance. The class can ensure that no other instance can be created (by intercepting requests to create new objects) and provide a way to access the instance.
When to Use:
- Exactly one instance of a class is required
- Controlled access to the sole instance is needed
- The instance should be extensible by subclassing
Real-World Examples:
- Database connection managers
- Logger systems
- Configuration managers
- Thread pools
- Cache systems
- Device drivers
Implementation in C++ (Thread-Safe):
#include <iostream>
#include <mutex>
#include <memory>
class DatabaseConnection {
public:
// Get the singleton instance
static DatabaseConnection& getInstance() {
// C++11 guarantees thread-safe initialization of static local variables
static DatabaseConnection instance;
return instance;
}
// Delete copy constructor and assignment operator
DatabaseConnection(const DatabaseConnection&) = delete;
DatabaseConnection& operator=(const DatabaseConnection&) = delete;
void connect(const std::string& connectionString) {
std::lock_guard<std::mutex> lock(mutex_);
if (!connected_) {
std::cout << "Connecting to database: " << connectionString << std::endl;
connected_ = true;
}
}
void query(const std::string& sql) {
std::lock_guard<std::mutex> lock(mutex_);
if (connected_) {
std::cout << "Executing: " << sql << std::endl;
} else {
std::cout << "Not connected!" << std::endl;
}
}
private:
DatabaseConnection() : connected_(false) {
std::cout << "DatabaseConnection instance created" << std::endl;
}
~DatabaseConnection() {
std::cout << "DatabaseConnection instance destroyed" << std::endl;
}
bool connected_;
std::mutex mutex_;
};
// Usage
int main() {
// All these calls return the same instance
DatabaseConnection::getInstance().connect("server=localhost;db=mydb");
DatabaseConnection::getInstance().query("SELECT * FROM users");
DatabaseConnection& db1 = DatabaseConnection::getInstance();
DatabaseConnection& db2 = DatabaseConnection::getInstance();
std::cout << "Same instance? " << (&db1 == &db2 ? "Yes" : "No") << std::endl;
return 0;
}
Implementation in Python:
import threading
class DatabaseConnection:
"""Thread-safe singleton using double-checked locking"""
_instance = None
_lock = threading.Lock()
def __new__(cls):
if cls._instance is None:
with cls._lock:
# Double-checked locking
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if not self._initialized:
self.connected = False
self._initialized = True
print("DatabaseConnection instance created")
def connect(self, connection_string):
if not self.connected:
print(f"Connecting to database: {connection_string}")
self.connected = True
def query(self, sql):
if self.connected:
print(f"Executing: {sql}")
else:
print("Not connected!")
# Python decorator approach (cleaner)
def singleton(cls):
instances = {}
lock = threading.Lock()
def get_instance(*args, **kwargs):
if cls not in instances:
with lock:
if cls not in instances:
instances[cls] = cls(*args, **kwargs)
return instances[cls]
return get_instance
@singleton
class Logger:
def __init__(self):
self.log_file = "app.log"
print("Logger initialized")
def log(self, message):
print(f"[LOG] {message}")
# Usage
db1 = DatabaseConnection()
db2 = DatabaseConnection()
print(f"Same instance? {db1 is db2}") # True
logger1 = Logger()
logger2 = Logger()
print(f"Same logger? {logger1 is logger2}") # True
Advantages:
- Controlled access to sole instance
- Reduced memory footprint
- Permits refinement of operations and representation
- Lazy initialization possible
Disadvantages:
- Can be difficult to test (global state)
- Violates Single Responsibility Principle
- Can mask bad design (tight coupling)
- Requires special treatment in multi-threaded environments
Factory Method Pattern
Intent: Define an interface for creating an object, but let subclasses decide which class to instantiate.
Problem: A framework needs to standardize the architectural model for a range of applications, but allow for individual applications to define their own domain objects and provide for their instantiation.
Solution: Define a factory method that returns objects of a common interface. Subclasses implement the factory method to create specific product types.
When to Use:
- A class can't anticipate the class of objects it must create
- A class wants its subclasses to specify the objects it creates
- Classes delegate responsibility to one of several helper subclasses
Real-World Examples:
- GUI frameworks creating platform-specific buttons/windows
- Document editors creating different document types
- Logistics apps creating different transport types
- Database connectors for different DBMS systems
Implementation in C++:
#include <iostream>
#include <memory>
#include <string>
// Product interface
class Transport {
public:
virtual ~Transport() = default;
virtual void deliver() = 0;
virtual std::string getType() const = 0;
};
// Concrete Products
class Truck : public Transport {
public:
void deliver() override {
std::cout << "Delivering by land in a box" << std::endl;
}
std::string getType() const override {
return "Truck";
}
};
class Ship : public Transport {
public:
void deliver() override {
std::cout << "Delivering by sea in a container" << std::endl;
}
std::string getType() const override {
return "Ship";
}
};
class Airplane : public Transport {
public:
void deliver() override {
std::cout << "Delivering by air in a cargo hold" << std::endl;
}
std::string getType() const override {
return "Airplane";
}
};
// Creator (Factory)
class Logistics {
public:
virtual ~Logistics() = default;
// Factory method
virtual std::unique_ptr<Transport> createTransport() = 0;
void planDelivery() {
auto transport = createTransport();
std::cout << "Planning delivery using " << transport->getType() << std::endl;
transport->deliver();
}
};
// Concrete Creators
class RoadLogistics : public Logistics {
public:
std::unique_ptr<Transport> createTransport() override {
return std::make_unique<Truck>();
}
};
class SeaLogistics : public Logistics {
public:
std::unique_ptr<Transport> createTransport() override {
return std::make_unique<Ship>();
}
};
class AirLogistics : public Logistics {
public:
std::unique_ptr<Transport> createTransport() override {
return std::make_unique<Airplane>();
}
};
// Usage
int main() {
std::unique_ptr<Logistics> logistics;
logistics = std::make_unique<RoadLogistics>();
logistics->planDelivery();
logistics = std::make_unique<SeaLogistics>();
logistics->planDelivery();
logistics = std::make_unique<AirLogistics>();
logistics->planDelivery();
return 0;
}
Implementation in Python:
from abc import ABC, abstractmethod
from typing import Protocol
# Product interface
class Transport(ABC):
@abstractmethod
def deliver(self) -> None:
pass
@abstractmethod
def get_type(self) -> str:
pass
# Concrete Products
class Truck(Transport):
def deliver(self) -> None:
print("Delivering by land in a box")
def get_type(self) -> str:
return "Truck"
class Ship(Transport):
def deliver(self) -> None:
print("Delivering by sea in a container")
def get_type(self) -> str:
return "Ship"
class Airplane(Transport):
def deliver(self) -> None:
print("Delivering by air in a cargo hold")
def get_type(self) -> str:
return "Airplane"
# Creator (Factory)
class Logistics(ABC):
@abstractmethod
def create_transport(self) -> Transport:
"""Factory method"""
pass
def plan_delivery(self) -> None:
transport = self.create_transport()
print(f"Planning delivery using {transport.get_type()}")
transport.deliver()
# Concrete Creators
class RoadLogistics(Logistics):
def create_transport(self) -> Transport:
return Truck()
class SeaLogistics(Logistics):
def create_transport(self) -> Transport:
return Ship()
class AirLogistics(Logistics):
def create_transport(self) -> Transport:
return Airplane()
# Usage
if __name__ == "__main__":
logistics = RoadLogistics()
logistics.plan_delivery()
logistics = SeaLogistics()
logistics.plan_delivery()
logistics = AirLogistics()
logistics.plan_delivery()
Advantages:
- Avoids tight coupling between creator and concrete products
- Single Responsibility Principle: product creation code in one place
- Open/Closed Principle: introduce new product types without breaking existing code
Disadvantages:
- Code can become more complicated with many new subclasses
- Requires subclassing just to create objects
Abstract Factory Pattern
Intent: Provide an interface for creating families of related or dependent objects without specifying their concrete classes.
Problem: You need to create families of related objects that must be used together, and you want to ensure compatibility between these objects.
Solution: Declare interfaces for creating each distinct product. Then create concrete factory classes that implement these interfaces for each product variant.
When to Use:
- System should be independent of how its products are created
- System should be configured with one of multiple families of products
- Family of related product objects must be used together
- You want to provide a class library of products without revealing implementations
Real-World Examples:
- GUI toolkits with different themes (Windows, Mac, Linux)
- Cross-platform UI libraries
- Database access libraries for different DBMS
- Document converters for different formats
Implementation in C++:
#include <iostream>
#include <memory>
#include <string>
// Abstract Products
class Button {
public:
virtual ~Button() = default;
virtual void paint() = 0;
virtual std::string getStyle() const = 0;
};
class Checkbox {
public:
virtual ~Checkbox() = default;
virtual void paint() = 0;
virtual std::string getStyle() const = 0;
};
class TextField {
public:
virtual ~TextField() = default;
virtual void paint() = 0;
virtual std::string getStyle() const = 0;
};
// Windows Products
class WindowsButton : public Button {
public:
void paint() override {
std::cout << "Rendering Windows-style button" << std::endl;
}
std::string getStyle() const override {
return "Windows";
}
};
class WindowsCheckbox : public Checkbox {
public:
void paint() override {
std::cout << "Rendering Windows-style checkbox" << std::endl;
}
std::string getStyle() const override {
return "Windows";
}
};
class WindowsTextField : public TextField {
public:
void paint() override {
std::cout << "Rendering Windows-style text field" << std::endl;
}
std::string getStyle() const override {
return "Windows";
}
};
// Mac Products
class MacButton : public Button {
public:
void paint() override {
std::cout << "Rendering Mac-style button" << std::endl;
}
std::string getStyle() const override {
return "Mac";
}
};
class MacCheckbox : public Checkbox {
public:
void paint() override {
std::cout << "Rendering Mac-style checkbox" << std::endl;
}
std::string getStyle() const override {
return "Mac";
}
};
class MacTextField : public TextField {
public:
void paint() override {
std::cout << "Rendering Mac-style text field" << std::endl;
}
std::string getStyle() const override {
return "Mac";
}
};
// Abstract Factory
class GUIFactory {
public:
virtual ~GUIFactory() = default;
virtual std::unique_ptr<Button> createButton() = 0;
virtual std::unique_ptr<Checkbox> createCheckbox() = 0;
virtual std::unique_ptr<TextField> createTextField() = 0;
};
// Concrete Factories
class WindowsFactory : public GUIFactory {
public:
std::unique_ptr<Button> createButton() override {
return std::make_unique<WindowsButton>();
}
std::unique_ptr<Checkbox> createCheckbox() override {
return std::make_unique<WindowsCheckbox>();
}
std::unique_ptr<TextField> createTextField() override {
return std::make_unique<WindowsTextField>();
}
};
class MacFactory : public GUIFactory {
public:
std::unique_ptr<Button> createButton() override {
return std::make_unique<MacButton>();
}
std::unique_ptr<Checkbox> createCheckbox() override {
return std::make_unique<MacCheckbox>();
}
std::unique_ptr<TextField> createTextField() override {
return std::make_unique<MacTextField>();
}
};
// Client code
class Application {
public:
Application(std::unique_ptr<GUIFactory> factory)
: factory_(std::move(factory)) {}
void createUI() {
button_ = factory_->createButton();
checkbox_ = factory_->createCheckbox();
textField_ = factory_->createTextField();
}
void paint() {
button_->paint();
checkbox_->paint();
textField_->paint();
}
private:
std::unique_ptr<GUIFactory> factory_;
std::unique_ptr<Button> button_;
std::unique_ptr<Checkbox> checkbox_;
std::unique_ptr<TextField> textField_;
};
// Usage
int main() {
std::string osType = "Windows"; // Could be detected at runtime
std::unique_ptr<GUIFactory> factory;
if (osType == "Windows") {
factory = std::make_unique<WindowsFactory>();
} else {
factory = std::make_unique<MacFactory>();
}
Application app(std::move(factory));
app.createUI();
app.paint();
return 0;
}
Implementation in Python:
from abc import ABC, abstractmethod
# Abstract Products
class Button(ABC):
@abstractmethod
def paint(self) -> None:
pass
@abstractmethod
def get_style(self) -> str:
pass
class Checkbox(ABC):
@abstractmethod
def paint(self) -> None:
pass
@abstractmethod
def get_style(self) -> str:
pass
class TextField(ABC):
@abstractmethod
def paint(self) -> None:
pass
@abstractmethod
def get_style(self) -> str:
pass
# Windows Products
class WindowsButton(Button):
def paint(self) -> None:
print("Rendering Windows-style button")
def get_style(self) -> str:
return "Windows"
class WindowsCheckbox(Checkbox):
def paint(self) -> None:
print("Rendering Windows-style checkbox")
def get_style(self) -> str:
return "Windows"
class WindowsTextField(TextField):
def paint(self) -> None:
print("Rendering Windows-style text field")
def get_style(self) -> str:
return "Windows"
# Mac Products
class MacButton(Button):
def paint(self) -> None:
print("Rendering Mac-style button")
def get_style(self) -> str:
return "Mac"
class MacCheckbox(Checkbox):
def paint(self) -> None:
print("Rendering Mac-style checkbox")
def get_style(self) -> str:
return "Mac"
class MacTextField(TextField):
def paint(self) -> None:
print("Rendering Mac-style text field")
def get_style(self) -> str:
return "Mac"
# Abstract Factory
class GUIFactory(ABC):
@abstractmethod
def create_button(self) -> Button:
pass
@abstractmethod
def create_checkbox(self) -> Checkbox:
pass
@abstractmethod
def create_text_field(self) -> TextField:
pass
# Concrete Factories
class WindowsFactory(GUIFactory):
def create_button(self) -> Button:
return WindowsButton()
def create_checkbox(self) -> Checkbox:
return WindowsCheckbox()
def create_text_field(self) -> TextField:
return WindowsTextField()
class MacFactory(GUIFactory):
def create_button(self) -> Button:
return MacButton()
def create_checkbox(self) -> Checkbox:
return MacCheckbox()
def create_text_field(self) -> TextField:
return MacTextField()
# Client code
class Application:
def __init__(self, factory: GUIFactory):
self.factory = factory
self.button = None
self.checkbox = None
self.text_field = None
def create_ui(self) -> None:
self.button = self.factory.create_button()
self.checkbox = self.factory.create_checkbox()
self.text_field = self.factory.create_text_field()
def paint(self) -> None:
self.button.paint()
self.checkbox.paint()
self.text_field.paint()
# Usage
if __name__ == "__main__":
import platform
os_type = platform.system()
if os_type == "Windows":
factory = WindowsFactory()
else:
factory = MacFactory()
app = Application(factory)
app.create_ui()
app.paint()
Advantages:
- Ensures compatibility between products from the same family
- Avoids tight coupling between concrete products and client code
- Single Responsibility Principle: product creation in one place
- Open/Closed Principle: introduce new variants without breaking existing code
Disadvantages:
- Code becomes more complicated due to many new interfaces and classes
- Adding new product types requires extending all factories
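The last disadvantage is easy to see in the example above: introducing a new product type forces a change to the abstract factory and to every concrete factory. A minimal Python sketch, using a hypothetical Slider product (not part of the GUI example above):

```python
from abc import ABC, abstractmethod

# Hypothetical new product type -- adding it means touching every factory.
class Slider(ABC):
    @abstractmethod
    def paint(self) -> None: ...

class WindowsSlider(Slider):
    def paint(self) -> None:
        print("Rendering Windows-style slider")

class MacSlider(Slider):
    def paint(self) -> None:
        print("Rendering Mac-style slider")

# The factory interface must grow a new abstract method...
class GUIFactory(ABC):
    @abstractmethod
    def create_slider(self) -> Slider: ...

# ...and every concrete factory must implement it.
class WindowsFactory(GUIFactory):
    def create_slider(self) -> Slider:  # change #1
        return WindowsSlider()

class MacFactory(GUIFactory):
    def create_slider(self) -> Slider:  # change #2
        return MacSlider()
```

Adding a new *family* (say, a Linux factory) is cheap; adding a new *product type* ripples through the whole factory hierarchy.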
Builder Pattern
Intent: Separate the construction of a complex object from its representation, allowing the same construction process to create different representations.
Problem: Creating complex objects with many optional components or configuration options leads to constructor pollution (too many constructor parameters) or many constructors.
Solution: Extract object construction code out of its own class and move it to separate objects called builders. The pattern organizes object construction into a set of steps.
When to Use:
- The algorithm for creating a complex object should be independent of the parts that make up the object and how they are assembled
- The construction process must allow different representations of the object being built
- Object has many optional parameters (telescoping constructor problem)
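The telescoping constructor problem mentioned above is worth seeing concretely: every optional field widens the constructor signature until positional call sites become unreadable. A minimal sketch with a simplified, hypothetical Pizza class:

```python
# Without a builder: one constructor with many optional parameters.
class Pizza:
    def __init__(self, size="Medium", dough="Thin", crust="Regular",
                 sauce="Tomato", cheese="Mozzarella", toppings=None):
        self.size = size
        self.dough = dough
        self.crust = crust
        self.sauce = sauce
        self.cheese = cheese
        self.toppings = toppings or []

# Positional call: which argument is which? Impossible to tell
# without the signature at hand.
unclear = Pizza("Large", "Thick", "Stuffed", "BBQ", "Cheddar", ["Chicken"])

# A builder (or, in Python, keyword arguments) names every step instead:
clear = Pizza(size="Large", crust="Stuffed", toppings=["Chicken"])
```

The Builder pattern below solves this in languages without keyword arguments, and even in Python it adds step-by-step construction and reuse of build sequences.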
Real-World Examples:
- Building complex documents (HTML, PDF)
- Creating database queries
- Building HTTP requests
- Constructing meals at restaurants
- Building cars with various options
Implementation in C++:
#include <iostream>
#include <string>
#include <vector>
#include <memory>
// Product
class Pizza {
public:
void setDough(const std::string& dough) { dough_ = dough; }
void setSauce(const std::string& sauce) { sauce_ = sauce; }
void setCheese(const std::string& cheese) { cheese_ = cheese; }
void addTopping(const std::string& topping) { toppings_.push_back(topping); }
void setSize(const std::string& size) { size_ = size; }
void setCrust(const std::string& crust) { crust_ = crust; }
void describe() const {
std::cout << "Pizza:" << std::endl;
std::cout << " Size: " << size_ << std::endl;
std::cout << " Dough: " << dough_ << std::endl;
std::cout << " Crust: " << crust_ << std::endl;
std::cout << " Sauce: " << sauce_ << std::endl;
std::cout << " Cheese: " << cheese_ << std::endl;
std::cout << " Toppings: ";
for (const auto& topping : toppings_) {
std::cout << topping << " ";
}
std::cout << std::endl;
}
private:
std::string dough_;
std::string sauce_;
std::string cheese_;
std::vector<std::string> toppings_;
std::string size_;
std::string crust_;
};
// Abstract Builder
class PizzaBuilder {
public:
virtual ~PizzaBuilder() = default;
virtual void buildDough() = 0;
virtual void buildSauce() = 0;
virtual void buildCheese() = 0;
virtual void buildToppings() = 0;
virtual void buildSize() = 0;
virtual void buildCrust() = 0;
std::unique_ptr<Pizza> getPizza() {
auto pizza = std::move(pizza_);
reset(); // leave the builder ready for the next build (pizza_ would otherwise be null)
return pizza;
}
void reset() { pizza_ = std::make_unique<Pizza>(); }
protected:
std::unique_ptr<Pizza> pizza_;
};
// Concrete Builder 1
class MargheritaPizzaBuilder : public PizzaBuilder {
public:
MargheritaPizzaBuilder() { reset(); }
void buildDough() override {
pizza_->setDough("Thin crust dough");
}
void buildSauce() override {
pizza_->setSauce("Tomato sauce");
}
void buildCheese() override {
pizza_->setCheese("Mozzarella");
}
void buildToppings() override {
pizza_->addTopping("Fresh basil");
pizza_->addTopping("Tomato slices");
}
void buildSize() override {
pizza_->setSize("Medium");
}
void buildCrust() override {
pizza_->setCrust("Regular");
}
};
// Concrete Builder 2
class PepperoniPizzaBuilder : public PizzaBuilder {
public:
PepperoniPizzaBuilder() { reset(); }
void buildDough() override {
pizza_->setDough("Thick crust dough");
}
void buildSauce() override {
pizza_->setSauce("Spicy tomato sauce");
}
void buildCheese() override {
pizza_->setCheese("Extra mozzarella");
}
void buildToppings() override {
pizza_->addTopping("Pepperoni");
pizza_->addTopping("Mushrooms");
pizza_->addTopping("Olives");
}
void buildSize() override {
pizza_->setSize("Large");
}
void buildCrust() override {
pizza_->setCrust("Stuffed");
}
};
// Director (optional but useful for complex builds)
class PizzaDirector {
public:
void setBuilder(PizzaBuilder* builder) {
builder_ = builder;
}
void makePizza() {
builder_->buildSize();
builder_->buildDough();
builder_->buildCrust();
builder_->buildSauce();
builder_->buildCheese();
builder_->buildToppings();
}
private:
PizzaBuilder* builder_;
};
// Fluent Builder Interface (Modern C++ approach)
class FluentPizzaBuilder {
public:
FluentPizzaBuilder() : pizza_(std::make_unique<Pizza>()) {}
FluentPizzaBuilder& setSize(const std::string& size) {
pizza_->setSize(size);
return *this;
}
FluentPizzaBuilder& setDough(const std::string& dough) {
pizza_->setDough(dough);
return *this;
}
FluentPizzaBuilder& setCrust(const std::string& crust) {
pizza_->setCrust(crust);
return *this;
}
FluentPizzaBuilder& setSauce(const std::string& sauce) {
pizza_->setSauce(sauce);
return *this;
}
FluentPizzaBuilder& setCheese(const std::string& cheese) {
pizza_->setCheese(cheese);
return *this;
}
FluentPizzaBuilder& addTopping(const std::string& topping) {
pizza_->addTopping(topping);
return *this;
}
std::unique_ptr<Pizza> build() {
return std::move(pizza_);
}
private:
std::unique_ptr<Pizza> pizza_;
};
// Usage
int main() {
// Traditional approach with director
PizzaDirector director;
MargheritaPizzaBuilder margheritaBuilder;
director.setBuilder(&margheritaBuilder);
director.makePizza();
auto margherita = margheritaBuilder.getPizza();
margherita->describe();
std::cout << "\n---\n\n";
PepperoniPizzaBuilder pepperoniBuilder;
director.setBuilder(&pepperoniBuilder);
director.makePizza();
auto pepperoni = pepperoniBuilder.getPizza();
pepperoni->describe();
std::cout << "\n---\n\n";
// Fluent interface approach
auto customPizza = FluentPizzaBuilder()
.setSize("Extra Large")
.setDough("Whole wheat")
.setCrust("Thin")
.setSauce("BBQ sauce")
.setCheese("Cheddar")
.addTopping("Chicken")
.addTopping("Onions")
.addTopping("Peppers")
.build();
customPizza->describe();
return 0;
}
Implementation in Python:
from typing import List
from abc import ABC, abstractmethod
# Product
class Pizza:
def __init__(self):
self.dough = ""
self.sauce = ""
self.cheese = ""
self.toppings: List[str] = []
self.size = ""
self.crust = ""
def describe(self) -> None:
print("Pizza:")
print(f" Size: {self.size}")
print(f" Dough: {self.dough}")
print(f" Crust: {self.crust}")
print(f" Sauce: {self.sauce}")
print(f" Cheese: {self.cheese}")
print(f" Toppings: {', '.join(self.toppings)}")
# Abstract Builder
class PizzaBuilder(ABC):
def __init__(self):
self.reset()
def reset(self) -> None:
self._pizza = Pizza()
@abstractmethod
def build_dough(self) -> None:
pass
@abstractmethod
def build_sauce(self) -> None:
pass
@abstractmethod
def build_cheese(self) -> None:
pass
@abstractmethod
def build_toppings(self) -> None:
pass
@abstractmethod
def build_size(self) -> None:
pass
@abstractmethod
def build_crust(self) -> None:
pass
def get_pizza(self) -> Pizza:
pizza = self._pizza
self.reset()
return pizza
# Concrete Builders
class MargheritaPizzaBuilder(PizzaBuilder):
def build_dough(self) -> None:
self._pizza.dough = "Thin crust dough"
def build_sauce(self) -> None:
self._pizza.sauce = "Tomato sauce"
def build_cheese(self) -> None:
self._pizza.cheese = "Mozzarella"
def build_toppings(self) -> None:
self._pizza.toppings = ["Fresh basil", "Tomato slices"]
def build_size(self) -> None:
self._pizza.size = "Medium"
def build_crust(self) -> None:
self._pizza.crust = "Regular"
class PepperoniPizzaBuilder(PizzaBuilder):
def build_dough(self) -> None:
self._pizza.dough = "Thick crust dough"
def build_sauce(self) -> None:
self._pizza.sauce = "Spicy tomato sauce"
def build_cheese(self) -> None:
self._pizza.cheese = "Extra mozzarella"
def build_toppings(self) -> None:
self._pizza.toppings = ["Pepperoni", "Mushrooms", "Olives"]
def build_size(self) -> None:
self._pizza.size = "Large"
def build_crust(self) -> None:
self._pizza.crust = "Stuffed"
# Director
class PizzaDirector:
def __init__(self, builder: PizzaBuilder = None):
self._builder = builder
def set_builder(self, builder: PizzaBuilder) -> None:
self._builder = builder
def make_pizza(self) -> None:
self._builder.build_size()
self._builder.build_dough()
self._builder.build_crust()
self._builder.build_sauce()
self._builder.build_cheese()
self._builder.build_toppings()
# Fluent Builder (Pythonic approach)
class FluentPizzaBuilder:
def __init__(self):
self._pizza = Pizza()
def set_size(self, size: str):
self._pizza.size = size
return self
def set_dough(self, dough: str):
self._pizza.dough = dough
return self
def set_crust(self, crust: str):
self._pizza.crust = crust
return self
def set_sauce(self, sauce: str):
self._pizza.sauce = sauce
return self
def set_cheese(self, cheese: str):
self._pizza.cheese = cheese
return self
def add_topping(self, topping: str):
self._pizza.toppings.append(topping)
return self
def build(self) -> Pizza:
return self._pizza
# Usage
if __name__ == "__main__":
# Traditional approach with director
director = PizzaDirector()
margherita_builder = MargheritaPizzaBuilder()
director.set_builder(margherita_builder)
director.make_pizza()
margherita = margherita_builder.get_pizza()
margherita.describe()
print("\n---\n")
pepperoni_builder = PepperoniPizzaBuilder()
director.set_builder(pepperoni_builder)
director.make_pizza()
pepperoni = pepperoni_builder.get_pizza()
pepperoni.describe()
print("\n---\n")
# Fluent interface approach
custom_pizza = (FluentPizzaBuilder()
.set_size("Extra Large")
.set_dough("Whole wheat")
.set_crust("Thin")
.set_sauce("BBQ sauce")
.set_cheese("Cheddar")
.add_topping("Chicken")
.add_topping("Onions")
.add_topping("Peppers")
.build())
custom_pizza.describe()
Advantages:
- Allows construction of complex objects step by step
- Can reuse same construction code for different representations
- Single Responsibility Principle: isolates complex construction code
- Solves the telescoping constructor problem
Disadvantages:
- Overall complexity increases (many new classes)
- Clients are tied to concrete builder classes
Prototype Pattern
Intent: Specify the kinds of objects to create using a prototypical instance, and create new objects by copying this prototype.
Problem: Creating objects is expensive (database queries, network calls, complex initialization), and you need many similar objects.
Solution: Delegate the cloning process to the actual objects being cloned. Declare a common interface for all objects that support cloning.
When to Use:
- Object creation is expensive
- You want to avoid a hierarchy of creator subclasses (alternative to Factory Method)
- Number of possible object states is limited
- Classes to instantiate are specified at runtime
Real-World Examples:
- Cell mitosis in biology
- Copying documents/files
- Cloning game objects with different skins
- Creating test data
- Undo/redo operations
Implementation in C++:
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
// Prototype interface
class Shape {
public:
virtual ~Shape() = default;
virtual std::unique_ptr<Shape> clone() const = 0;
virtual void draw() const = 0;
virtual std::string getType() const = 0;
// Common properties
int x_, y_;
std::string color_;
protected:
Shape() : x_(0), y_(0), color_("black") {}
Shape(int x, int y, const std::string& color)
: x_(x), y_(y), color_(color) {}
};
// Concrete Prototype 1
class Circle : public Shape {
public:
int radius_;
Circle() : Shape(), radius_(10) {}
Circle(int x, int y, const std::string& color, int radius)
: Shape(x, y, color), radius_(radius) {}
// Copy constructor
Circle(const Circle& other)
: Shape(other.x_, other.y_, other.color_), radius_(other.radius_) {
std::cout << "Circle copied" << std::endl;
}
std::unique_ptr<Shape> clone() const override {
return std::make_unique<Circle>(*this);
}
void draw() const override {
std::cout << "Circle at (" << x_ << "," << y_ << ") "
<< "with radius " << radius_
<< " and color " << color_ << std::endl;
}
std::string getType() const override {
return "Circle";
}
};
// Concrete Prototype 2
class Rectangle : public Shape {
public:
int width_, height_;
Rectangle() : Shape(), width_(20), height_(10) {}
Rectangle(int x, int y, const std::string& color, int width, int height)
: Shape(x, y, color), width_(width), height_(height) {}
// Copy constructor
Rectangle(const Rectangle& other)
: Shape(other.x_, other.y_, other.color_),
width_(other.width_), height_(other.height_) {
std::cout << "Rectangle copied" << std::endl;
}
std::unique_ptr<Shape> clone() const override {
return std::make_unique<Rectangle>(*this);
}
void draw() const override {
std::cout << "Rectangle at (" << x_ << "," << y_ << ") "
<< "with size " << width_ << "x" << height_
<< " and color " << color_ << std::endl;
}
std::string getType() const override {
return "Rectangle";
}
};
// Prototype Registry (Prototype Manager)
class ShapeCache {
public:
static ShapeCache& getInstance() {
static ShapeCache instance;
return instance;
}
void loadCache() {
auto circle = std::make_unique<Circle>(0, 0, "red", 15);
prototypes_["red_circle"] = std::move(circle);
auto rectangle = std::make_unique<Rectangle>(0, 0, "blue", 30, 20);
prototypes_["blue_rectangle"] = std::move(rectangle);
auto smallCircle = std::make_unique<Circle>(0, 0, "green", 5);
prototypes_["small_green_circle"] = std::move(smallCircle);
}
std::unique_ptr<Shape> getShape(const std::string& type) {
auto it = prototypes_.find(type);
if (it != prototypes_.end()) {
return it->second->clone();
}
return nullptr;
}
void addShape(const std::string& key, std::unique_ptr<Shape> shape) {
prototypes_[key] = std::move(shape);
}
private:
ShapeCache() = default;
std::unordered_map<std::string, std::unique_ptr<Shape>> prototypes_;
};
// Usage
int main() {
// Load predefined prototypes
ShapeCache::getInstance().loadCache();
// Clone shapes from cache
auto shape1 = ShapeCache::getInstance().getShape("red_circle");
shape1->x_ = 10;
shape1->y_ = 20;
shape1->draw();
auto shape2 = ShapeCache::getInstance().getShape("red_circle");
shape2->x_ = 50;
shape2->y_ = 60;
shape2->draw();
auto shape3 = ShapeCache::getInstance().getShape("blue_rectangle");
shape3->x_ = 100;
shape3->y_ = 100;
shape3->draw();
// Add custom prototype
auto customCircle = std::make_unique<Circle>(0, 0, "yellow", 25);
ShapeCache::getInstance().addShape("custom_yellow", std::move(customCircle));
auto shape4 = ShapeCache::getInstance().getShape("custom_yellow");
shape4->draw();
return 0;
}
Implementation in Python:
import copy
from abc import ABC, abstractmethod
from typing import Dict
# Prototype interface
class Shape(ABC):
def __init__(self, x: int = 0, y: int = 0, color: str = "black"):
self.x = x
self.y = y
self.color = color
@abstractmethod
def clone(self):
"""Return a deep copy of the object"""
pass
@abstractmethod
def draw(self) -> None:
pass
@abstractmethod
def get_type(self) -> str:
pass
# Concrete Prototype 1
class Circle(Shape):
def __init__(self, x: int = 0, y: int = 0, color: str = "black", radius: int = 10):
super().__init__(x, y, color)
self.radius = radius
def clone(self):
"""Deep copy using copy module"""
print("Circle copied")
return copy.deepcopy(self)
def draw(self) -> None:
print(f"Circle at ({self.x},{self.y}) with radius {self.radius} and color {self.color}")
def get_type(self) -> str:
return "Circle"
# Concrete Prototype 2
class Rectangle(Shape):
def __init__(self, x: int = 0, y: int = 0, color: str = "black",
width: int = 20, height: int = 10):
super().__init__(x, y, color)
self.width = width
self.height = height
def clone(self):
"""Deep copy using copy module"""
print("Rectangle copied")
return copy.deepcopy(self)
def draw(self) -> None:
print(f"Rectangle at ({self.x},{self.y}) with size {self.width}x{self.height} and color {self.color}")
def get_type(self) -> str:
return "Rectangle"
# Prototype Registry
class ShapeCache:
_instance = None
_prototypes: Dict[str, Shape] = {}
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def load_cache(self) -> None:
"""Load predefined prototypes"""
self._prototypes["red_circle"] = Circle(0, 0, "red", 15)
self._prototypes["blue_rectangle"] = Rectangle(0, 0, "blue", 30, 20)
self._prototypes["small_green_circle"] = Circle(0, 0, "green", 5)
def get_shape(self, shape_type: str) -> Shape:
"""Clone a shape from the cache"""
prototype = self._prototypes.get(shape_type)
if prototype:
return prototype.clone()
raise ValueError(f"Shape type '{shape_type}' not found in cache")
def add_shape(self, key: str, shape: Shape) -> None:
"""Add a new prototype to the cache"""
self._prototypes[key] = shape
# Usage
if __name__ == "__main__":
# Load predefined prototypes
cache = ShapeCache()
cache.load_cache()
# Clone shapes from cache
shape1 = cache.get_shape("red_circle")
shape1.x, shape1.y = 10, 20
shape1.draw()
shape2 = cache.get_shape("red_circle")
shape2.x, shape2.y = 50, 60
shape2.draw()
shape3 = cache.get_shape("blue_rectangle")
shape3.x, shape3.y = 100, 100
shape3.draw()
# Add custom prototype
custom_circle = Circle(0, 0, "yellow", 25)
cache.add_shape("custom_yellow", custom_circle)
shape4 = cache.get_shape("custom_yellow")
shape4.draw()
Advantages:
- Reduces cost of creating complex objects
- Hides complexity of creating new instances
- Allows adding/removing products at runtime
- Configures application with classes dynamically
Disadvantages:
- Cloning complex objects with circular references can be tricky
- Deep vs shallow copy considerations
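The shallow-vs-deep distinction matters because a shallow copy silently shares mutable members between the prototype and its clone. A small sketch using Python's standard copy module:

```python
import copy

class Shape:
    def __init__(self, color, tags):
        self.color = color
        self.tags = tags  # mutable list

prototype = Shape("red", ["filled"])

shallow = copy.copy(prototype)    # clone shares the same tags list
deep = copy.deepcopy(prototype)   # clone gets its own tags list

shallow.tags.append("shared!")    # also visible through the prototype
deep.tags.append("private")       # invisible to the prototype

print(prototype.tags)  # ['filled', 'shared!']
```

This is why the Python implementation above uses copy.deepcopy in clone(): handing out shallow copies of a registry prototype would let one client mutate every later clone.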
Related Patterns:
- Often used with Composite and Decorator patterns
- Designs that use Factory Method can use Prototype instead
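Regarding circular references: in Python, copy.deepcopy handles cycles automatically through its memo dictionary, which is one reason the implementation above delegates to it rather than cloning field by field. A quick sketch:

```python
import copy

class Node:
    def __init__(self, name):
        self.name = name
        self.partner = None

a = Node("a")
b = Node("b")
a.partner = b
b.partner = a  # circular reference

# deepcopy's memo dict records already-copied objects,
# so the cycle is copied once instead of recursing forever.
clone = copy.deepcopy(a)
print(clone.partner.partner is clone)  # True: the cycle is preserved
```

Hand-written clone logic (as in the C++ example) must take the same care explicitly when objects reference each other.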
Structural Patterns
Adapter Pattern
Intent: Convert the interface of a class into another interface that clients expect. Adapter lets classes work together that couldn't otherwise because of incompatible interfaces.
Problem: You want to use an existing class, but its interface doesn't match the one you need. You can't modify the existing class (third-party library, legacy code, or you want to keep it unchanged).
Solution: Create an adapter class that wraps the incompatible object and translates calls from the expected interface to the adaptee's interface. There are two main approaches: class adapter (using multiple inheritance) and object adapter (using composition).
When to Use:
- You want to use an existing class with an incompatible interface
- You need to create a reusable class that cooperates with unrelated classes
- You need to use several existing subclasses, but it's impractical to adapt their interface by subclassing each one (use object adapter)
- Integrating legacy code with new systems
- Working with third-party libraries
Real-World Examples:
- Power adapters (110V to 220V conversion)
- Card readers for different memory card formats
- Media player supporting multiple audio formats
- Database drivers adapting different database APIs
- XML to JSON converters
- Legacy system integration
Implementation in C++ (Object Adapter):
#include <iostream>
#include <memory>
#include <string>
#include <cmath>
// Target interface - What the client expects
class MediaPlayer {
public:
virtual ~MediaPlayer() = default;
virtual void play(const std::string& audioType, const std::string& fileName) = 0;
};
// Adaptee 1 - Advanced MP4 player with incompatible interface
class AdvancedMP4Player {
public:
void playMP4(const std::string& fileName) {
std::cout << "Playing MP4 file: " << fileName << std::endl;
}
};
// Adaptee 2 - VLC player with incompatible interface
class VLCPlayer {
public:
void playVLC(const std::string& fileName) {
std::cout << "Playing VLC file: " << fileName << std::endl;
}
};
// Adapter - Adapts AdvancedMP4Player and VLCPlayer to MediaPlayer interface
class MediaAdapter : public MediaPlayer {
public:
MediaAdapter(const std::string& audioType) {
if (audioType == "mp4") {
mp4Player_ = std::make_unique<AdvancedMP4Player>();
} else if (audioType == "vlc") {
vlcPlayer_ = std::make_unique<VLCPlayer>();
}
}
void play(const std::string& audioType, const std::string& fileName) override {
if (audioType == "mp4") {
mp4Player_->playMP4(fileName);
} else if (audioType == "vlc") {
vlcPlayer_->playVLC(fileName);
}
}
private:
std::unique_ptr<AdvancedMP4Player> mp4Player_;
std::unique_ptr<VLCPlayer> vlcPlayer_;
};
// Concrete implementation of target interface
class AudioPlayer : public MediaPlayer {
public:
void play(const std::string& audioType, const std::string& fileName) override {
// Built-in support for mp3
if (audioType == "mp3") {
std::cout << "Playing MP3 file: " << fileName << std::endl;
}
// Use adapter for other formats
else if (audioType == "mp4" || audioType == "vlc") {
auto adapter = std::make_unique<MediaAdapter>(audioType);
adapter->play(audioType, fileName);
} else {
std::cout << "Invalid media type: " << audioType << std::endl;
}
}
};
// Real-world example: Shape compatibility (legacy square to new rectangle interface)
class LegacyRectangle {
public:
void draw(int x1, int y1, int x2, int y2) {
std::cout << "Legacy Rectangle from (" << x1 << "," << y1
<< ") to (" << x2 << "," << y2 << ")" << std::endl;
}
};
// New shape interface
class Shape {
public:
virtual ~Shape() = default;
virtual void draw() = 0;
virtual void resize(int percentage) = 0;
};
// Adapter for legacy rectangle
class RectangleAdapter : public Shape {
public:
RectangleAdapter(int x, int y, int width, int height)
: x_(x), y_(y), width_(width), height_(height) {
legacyRect_ = std::make_unique<LegacyRectangle>();
}
void draw() override {
legacyRect_->draw(x_, y_, x_ + width_, y_ + height_);
}
void resize(int percentage) override {
width_ = static_cast<int>(width_ * percentage / 100.0);
height_ = static_cast<int>(height_ * percentage / 100.0);
std::cout << "Resized to " << width_ << "x" << height_ << std::endl;
}
private:
std::unique_ptr<LegacyRectangle> legacyRect_;
int x_, y_, width_, height_;
};
Class Adapter Example (Using Multiple Inheritance):
// Class adapter - inherits from both target and adaptee
class ClassMediaAdapter : public MediaPlayer, private AdvancedMP4Player {
public:
void play(const std::string& audioType, const std::string& fileName) override {
if (audioType == "mp4") {
playMP4(fileName); // Direct call to inherited method
}
}
};
// Usage
int main() {
// Object adapter example
AudioPlayer player;
player.play("mp3", "song.mp3");
player.play("mp4", "video.mp4");
player.play("vlc", "movie.vlc");
player.play("avi", "movie.avi");
std::cout << "\n---\n\n";
// Shape adapter example
RectangleAdapter rect(10, 20, 100, 50);
rect.draw();
rect.resize(150);
rect.draw();
return 0;
}
Implementation in Python:
from abc import ABC, abstractmethod
# Target interface
class MediaPlayer(ABC):
@abstractmethod
def play(self, audio_type: str, file_name: str) -> None:
pass
# Adaptee 1
class AdvancedMP4Player:
def play_mp4(self, file_name: str) -> None:
print(f"Playing MP4 file: {file_name}")
# Adaptee 2
class VLCPlayer:
def play_vlc(self, file_name: str) -> None:
print(f"Playing VLC file: {file_name}")
# Adapter
class MediaAdapter(MediaPlayer):
def __init__(self, audio_type: str):
self.audio_type = audio_type
if audio_type == "mp4":
self.advanced_player = AdvancedMP4Player()
elif audio_type == "vlc":
self.advanced_player = VLCPlayer()
else:
# Fail fast instead of leaving advanced_player unset (AttributeError later)
raise ValueError(f"Unsupported media type: {audio_type}")
def play(self, audio_type: str, file_name: str) -> None:
if audio_type == "mp4":
self.advanced_player.play_mp4(file_name)
elif audio_type == "vlc":
self.advanced_player.play_vlc(file_name)
# Concrete target
class AudioPlayer(MediaPlayer):
def play(self, audio_type: str, file_name: str) -> None:
# Built-in support for mp3
if audio_type == "mp3":
print(f"Playing MP3 file: {file_name}")
# Use adapter for other formats
elif audio_type in ["mp4", "vlc"]:
adapter = MediaAdapter(audio_type)
adapter.play(audio_type, file_name)
else:
print(f"Invalid media type: {audio_type}")
# Legacy system adapter example
class LegacyRectangle:
def draw(self, x1: int, y1: int, x2: int, y2: int) -> None:
print(f"Legacy Rectangle from ({x1},{y1}) to ({x2},{y2})")
class Shape(ABC):
@abstractmethod
def draw(self) -> None:
pass
@abstractmethod
def resize(self, percentage: int) -> None:
pass
class RectangleAdapter(Shape):
def __init__(self, x: int, y: int, width: int, height: int):
self.legacy_rect = LegacyRectangle()
self.x = x
self.y = y
self.width = width
self.height = height
def draw(self) -> None:
self.legacy_rect.draw(self.x, self.y, self.x + self.width, self.y + self.height)
def resize(self, percentage: int) -> None:
self.width = int(self.width * percentage / 100)
self.height = int(self.height * percentage / 100)
print(f"Resized to {self.width}x{self.height}")
# Usage
if __name__ == "__main__":
player = AudioPlayer()
player.play("mp3", "song.mp3")
player.play("mp4", "video.mp4")
player.play("vlc", "movie.vlc")
player.play("avi", "movie.avi")
print("\n---\n")
rect = RectangleAdapter(10, 20, 100, 50)
rect.draw()
rect.resize(150)
rect.draw()
Advantages:
- Single Responsibility Principle: separate interface conversion from business logic
- Open/Closed Principle: introduce new adapters without changing existing code
- Flexibility in adapting multiple incompatible interfaces
- Reuses existing functionality without modification
Disadvantages:
- Overall complexity increases due to new interfaces and classes
- Sometimes it's simpler to just change the service class to match the rest of your code
Related Patterns:
- Bridge: Separates interface from implementation (designed upfront), whereas Adapter makes existing classes work together (retrofitted)
- Decorator: Enhances object without changing interface; Adapter changes the interface
- Proxy: Provides same interface; Adapter provides different interface
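In Python, duck typing allows a lighter variant of the object adapter: the adapter need not subclass a target ABC at all, it only has to expose the method clients expect and translate calls to the adaptee. A small sketch with illustrative class names (not from the examples above):

```python
class FahrenheitSensor:
    """Adaptee: reports temperature in Fahrenheit."""
    def read_f(self):
        return 98.6

class CelsiusAdapter:
    """Adapter: exposes the read_c() interface clients expect."""
    def __init__(self, sensor):
        self._sensor = sensor

    def read_c(self):
        # Translate the adaptee's interface to the target one.
        return (self._sensor.read_f() - 32) * 5 / 9

adapter = CelsiusAdapter(FahrenheitSensor())
print(round(adapter.read_c(), 1))  # 37.0
```

Any client written against a `read_c()` interface can now consume the Fahrenheit sensor unchanged, which is the whole point of the pattern.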
Bridge Pattern
Intent: Decouple an abstraction from its implementation so that the two can vary independently.
Problem: When an abstraction can have multiple implementations and you want to avoid a permanent binding between them. Without Bridge, you end up with a combinatorial explosion of subclasses (e.g., Shape → CircleShape, SquareShape; Renderer → OpenGLRenderer, DirectXRenderer → OpenGLCircle, DirectXCircle, OpenGLSquare, DirectXSquare).
Solution: Separate the abstraction hierarchy from the implementation hierarchy. The abstraction contains a reference to the implementation and delegates the actual work to it.
When to Use:
- You want to avoid permanent binding between abstraction and implementation
- Both abstraction and implementation should be extensible by subclassing
- Changes in implementation shouldn't affect client code
- You want to share implementation among multiple objects (copy-on-write)
- You have a proliferation of classes from a coupled interface/implementation
Real-World Examples:
- Graphics rendering across different platforms (OpenGL, DirectX, Vulkan)
- Database drivers (abstract DB operations vs specific database implementations)
- GUI frameworks across operating systems
- Remote controls and devices (abstraction: remote, implementation: TV, radio, etc.)
- Payment processing across different payment gateways
Implementation in C++:
#include <iostream>
#include <memory>
#include <string>
#include <algorithm> // std::min / std::max used by RemoteControl
// Implementation hierarchy
class Renderer {
public:
virtual ~Renderer() = default;
virtual void renderCircle(float radius) = 0;
virtual void renderSquare(float side) = 0;
virtual std::string getName() const = 0;
};
class OpenGLRenderer : public Renderer {
public:
void renderCircle(float radius) override {
std::cout << "[OpenGL] Drawing circle with radius " << radius << std::endl;
}
void renderSquare(float side) override {
std::cout << "[OpenGL] Drawing square with side " << side << std::endl;
}
std::string getName() const override {
return "OpenGL";
}
};
class DirectXRenderer : public Renderer {
public:
void renderCircle(float radius) override {
std::cout << "[DirectX] Rendering circle with radius " << radius << std::endl;
}
void renderSquare(float side) override {
std::cout << "[DirectX] Rendering square with side " << side << std::endl;
}
std::string getName() const override {
return "DirectX";
}
};
class VulkanRenderer : public Renderer {
public:
void renderCircle(float radius) override {
std::cout << "[Vulkan] Rendering circle with radius " << radius << std::endl;
}
void renderSquare(float side) override {
std::cout << "[Vulkan] Rendering square with side " << side << std::endl;
}
std::string getName() const override {
return "Vulkan";
}
};
// Abstraction hierarchy
class Shape {
public:
virtual ~Shape() = default;
Shape(std::unique_ptr<Renderer> renderer)
: renderer_(std::move(renderer)) {}
virtual void draw() = 0;
virtual void resize(float factor) = 0;
protected:
std::unique_ptr<Renderer> renderer_;
};
class Circle : public Shape {
public:
Circle(std::unique_ptr<Renderer> renderer, float radius)
: Shape(std::move(renderer)), radius_(radius) {}
void draw() override {
std::cout << "Circle: ";
renderer_->renderCircle(radius_);
}
void resize(float factor) override {
radius_ *= factor;
std::cout << "Circle resized to radius " << radius_ << std::endl;
}
private:
float radius_;
};
class Square : public Shape {
public:
Square(std::unique_ptr<Renderer> renderer, float side)
: Shape(std::move(renderer)), side_(side) {}
void draw() override {
std::cout << "Square: ";
renderer_->renderSquare(side_);
}
void resize(float factor) override {
side_ *= factor;
std::cout << "Square resized to side " << side_ << std::endl;
}
private:
float side_;
};
// Real-world example: Remote control and devices
class Device {
public:
virtual ~Device() = default;
virtual void powerOn() = 0;
virtual void powerOff() = 0;
virtual void setVolume(int volume) = 0;
virtual void setChannel(int channel) = 0;
};
class TV : public Device {
public:
void powerOn() override {
std::cout << "TV: Power ON" << std::endl;
}
void powerOff() override {
std::cout << "TV: Power OFF" << std::endl;
}
void setVolume(int volume) override {
std::cout << "TV: Setting volume to " << volume << std::endl;
}
void setChannel(int channel) override {
std::cout << "TV: Switching to channel " << channel << std::endl;
}
};
class Radio : public Device {
public:
void powerOn() override {
std::cout << "Radio: Power ON" << std::endl;
}
void powerOff() override {
std::cout << "Radio: Power OFF" << std::endl;
}
void setVolume(int volume) override {
std::cout << "Radio: Setting volume to " << volume << std::endl;
}
void setChannel(int channel) override {
std::cout << "Radio: Tuning to station " << channel << " MHz" << std::endl;
}
};
class RemoteControl {
public:
RemoteControl(std::shared_ptr<Device> device)
: device_(device) {}
virtual ~RemoteControl() = default;
void togglePower() {
if (isOn_) {
device_->powerOff();
isOn_ = false;
} else {
device_->powerOn();
isOn_ = true;
}
}
void volumeUp() {
volume_ = std::min(volume_ + 10, 100);
device_->setVolume(volume_);
}
void volumeDown() {
volume_ = std::max(volume_ - 10, 0);
device_->setVolume(volume_);
}
void channelUp() {
channel_++;
device_->setChannel(channel_);
}
void channelDown() {
channel_--;
device_->setChannel(channel_);
}
protected:
std::shared_ptr<Device> device_;
bool isOn_ = false;
int volume_ = 50;
int channel_ = 1;
};
class AdvancedRemoteControl : public RemoteControl {
public:
using RemoteControl::RemoteControl;
void mute() {
device_->setVolume(0);
std::cout << "Device muted" << std::endl;
}
};
// Usage
int main() {
// Bridge pattern with shapes and renderers
auto circle1 = std::make_unique<Circle>(std::make_unique<OpenGLRenderer>(), 5.0f);
circle1->draw();
circle1->resize(1.5f);
circle1->draw();
std::cout << "\n";
auto square1 = std::make_unique<Square>(std::make_unique<DirectXRenderer>(), 10.0f);
square1->draw();
std::cout << "\n";
auto circle2 = std::make_unique<Circle>(std::make_unique<VulkanRenderer>(), 7.0f);
circle2->draw();
std::cout << "\n---\n\n";
// Remote control example
auto tv = std::make_shared<TV>();
RemoteControl tvRemote(tv);
tvRemote.togglePower();
tvRemote.volumeUp();
tvRemote.channelUp();
std::cout << "\n";
auto radio = std::make_shared<Radio>();
AdvancedRemoteControl radioRemote(radio);
radioRemote.togglePower();
radioRemote.volumeUp();
radioRemote.mute();
return 0;
}
Implementation in Python:
from abc import ABC, abstractmethod
from typing import Protocol
# Implementation hierarchy
class Renderer(ABC):
@abstractmethod
def render_circle(self, radius: float) -> None:
pass
@abstractmethod
def render_square(self, side: float) -> None:
pass
@abstractmethod
def get_name(self) -> str:
pass
class OpenGLRenderer(Renderer):
def render_circle(self, radius: float) -> None:
print(f"[OpenGL] Drawing circle with radius {radius}")
def render_square(self, side: float) -> None:
print(f"[OpenGL] Drawing square with side {side}")
def get_name(self) -> str:
return "OpenGL"
class DirectXRenderer(Renderer):
def render_circle(self, radius: float) -> None:
print(f"[DirectX] Rendering circle with radius {radius}")
def render_square(self, side: float) -> None:
print(f"[DirectX] Rendering square with side {side}")
def get_name(self) -> str:
return "DirectX"
# Abstraction hierarchy
class Shape(ABC):
def __init__(self, renderer: Renderer):
self.renderer = renderer
@abstractmethod
def draw(self) -> None:
pass
@abstractmethod
def resize(self, factor: float) -> None:
pass
class Circle(Shape):
def __init__(self, renderer: Renderer, radius: float):
super().__init__(renderer)
self.radius = radius
def draw(self) -> None:
print("Circle: ", end="")
self.renderer.render_circle(self.radius)
def resize(self, factor: float) -> None:
self.radius *= factor
print(f"Circle resized to radius {self.radius}")
class Square(Shape):
def __init__(self, renderer: Renderer, side: float):
super().__init__(renderer)
self.side = side
def draw(self) -> None:
print("Square: ", end="")
self.renderer.render_square(self.side)
def resize(self, factor: float) -> None:
self.side *= factor
print(f"Square resized to side {self.side}")
# Device example
class Device(ABC):
@abstractmethod
def power_on(self) -> None:
pass
@abstractmethod
def power_off(self) -> None:
pass
@abstractmethod
def set_volume(self, volume: int) -> None:
pass
@abstractmethod
def set_channel(self, channel: int) -> None:
pass
class TV(Device):
def power_on(self) -> None:
print("TV: Power ON")
def power_off(self) -> None:
print("TV: Power OFF")
def set_volume(self, volume: int) -> None:
print(f"TV: Setting volume to {volume}")
def set_channel(self, channel: int) -> None:
print(f"TV: Switching to channel {channel}")
class RemoteControl:
def __init__(self, device: Device):
self.device = device
self.is_on = False
self.volume = 50
self.channel = 1
def toggle_power(self) -> None:
if self.is_on:
self.device.power_off()
self.is_on = False
else:
self.device.power_on()
self.is_on = True
def volume_up(self) -> None:
self.volume = min(self.volume + 10, 100)
self.device.set_volume(self.volume)
def channel_up(self) -> None:
self.channel += 1
self.device.set_channel(self.channel)
# Usage
if __name__ == "__main__":
circle1 = Circle(OpenGLRenderer(), 5.0)
circle1.draw()
circle1.resize(1.5)
circle1.draw()
print()
square1 = Square(DirectXRenderer(), 10.0)
square1.draw()
print("\n---\n")
tv = TV()
remote = RemoteControl(tv)
remote.toggle_power()
remote.volume_up()
remote.channel_up()
Advantages:
- Decouples interface from implementation
- Improves extensibility (extend abstraction and implementation independently)
- Hides implementation details from client
- Allows switching implementations at runtime
- Reduces number of subclasses in hierarchies
Disadvantages:
- Increases complexity with additional layers of indirection
- Can be harder to understand initially
Related Patterns:
- Abstract Factory: Can create and configure a particular Bridge
- Adapter: Makes unrelated classes work together (retrofitted); Bridge separates abstraction from implementation (designed upfront)
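The runtime-switching advantage listed above can be made concrete with a small sketch that mirrors the Python implementation's `Circle`/`Renderer` classes. The `set_renderer` helper and the `SVGRenderer` backend are illustrative assumptions, not part of the implementation above:

```python
class Renderer:
    def render_circle(self, radius):
        raise NotImplementedError

class OpenGLRenderer(Renderer):
    def render_circle(self, radius):
        print(f"[OpenGL] Drawing circle with radius {radius}")

class SVGRenderer(Renderer):
    def render_circle(self, radius):
        print(f"[SVG] <circle r='{radius}'/>")

class Circle:
    def __init__(self, renderer, radius):
        self.renderer = renderer
        self.radius = radius

    def set_renderer(self, renderer):
        # Swap the implementation without touching the abstraction
        self.renderer = renderer

    def draw(self):
        self.renderer.render_circle(self.radius)

c = Circle(OpenGLRenderer(), 5.0)
c.draw()                       # rendered via the OpenGL backend
c.set_renderer(SVGRenderer())  # switch implementation at runtime
c.draw()                       # same object, now the SVG backend
```

Because the abstraction holds only a reference to the implementation, nothing else in the client code changes when the backend is swapped.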
Composite Pattern
Intent: Compose objects into tree structures to represent part-whole hierarchies. Composite lets clients treat individual objects and compositions of objects uniformly.
Problem: You need to represent a hierarchy of objects where individual objects and groups of objects should be treated uniformly. Without Composite, clients must differentiate between leaf nodes and branches.
Solution: Define a common interface for both simple (leaf) and complex (composite) objects. Composite objects delegate operations to their children.
When to Use:
- You want to represent part-whole hierarchies
- You want clients to ignore the difference between compositions and individual objects
- You have tree-structured data (file systems, GUI components, organization charts)
- You need recursive composition
Real-World Examples:
- File systems (files and directories)
- GUI component hierarchies (panels containing buttons, labels, other panels)
- Organization charts (departments containing employees and sub-departments)
- Graphics scenes (shapes containing other shapes)
- Menu systems (menus containing menu items and submenus)
Implementation in C++:
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include <algorithm>
#include <stdexcept>
// Component - Common interface for leaves and composites
class FileSystemComponent {
public:
virtual ~FileSystemComponent() = default;
virtual std::string getName() const = 0;
virtual int getSize() const = 0;
virtual void display(int depth = 0) const = 0;
// Composite methods (default implementations)
virtual void add(std::shared_ptr<FileSystemComponent> component) {
throw std::runtime_error("Operation not supported");
}
virtual void remove(std::shared_ptr<FileSystemComponent> component) {
throw std::runtime_error("Operation not supported");
}
virtual std::shared_ptr<FileSystemComponent> getChild(int index) {
throw std::runtime_error("Operation not supported");
}
};
// Leaf - File
class File : public FileSystemComponent {
public:
File(const std::string& name, int size)
: name_(name), size_(size) {}
std::string getName() const override {
return name_;
}
int getSize() const override {
return size_;
}
void display(int depth = 0) const override {
std::string indent(depth * 2, ' ');
std::cout << indent << "📄 " << name_ << " (" << size_ << " KB)" << std::endl;
}
private:
std::string name_;
int size_;
};
// Composite - Directory
class Directory : public FileSystemComponent {
public:
Directory(const std::string& name)
: name_(name) {}
std::string getName() const override {
return name_;
}
int getSize() const override {
int totalSize = 0;
for (const auto& child : children_) {
totalSize += child->getSize();
}
return totalSize;
}
void display(int depth = 0) const override {
std::string indent(depth * 2, ' ');
std::cout << indent << "📁 " << name_ << " (" << getSize() << " KB total)" << std::endl;
for (const auto& child : children_) {
child->display(depth + 1);
}
}
void add(std::shared_ptr<FileSystemComponent> component) override {
children_.push_back(component);
}
void remove(std::shared_ptr<FileSystemComponent> component) override {
auto it = std::find(children_.begin(), children_.end(), component);
if (it != children_.end()) {
children_.erase(it);
}
}
std::shared_ptr<FileSystemComponent> getChild(int index) override {
if (index >= 0 && index < static_cast<int>(children_.size())) {
return children_[index];
}
return nullptr;
}
private:
std::string name_;
std::vector<std::shared_ptr<FileSystemComponent>> children_;
};
// Another example: Graphics
class Graphic {
public:
virtual ~Graphic() = default;
virtual void draw() const = 0;
virtual void move(int x, int y) = 0;
};
class Circle : public Graphic {
public:
Circle(int x, int y, int radius)
: x_(x), y_(y), radius_(radius) {}
void draw() const override {
std::cout << "Circle at (" << x_ << "," << y_ << ") with radius " << radius_ << std::endl;
}
void move(int x, int y) override {
x_ += x;
y_ += y;
}
private:
int x_, y_, radius_;
};
class Rectangle : public Graphic {
public:
Rectangle(int x, int y, int width, int height)
: x_(x), y_(y), width_(width), height_(height) {}
void draw() const override {
std::cout << "Rectangle at (" << x_ << "," << y_ << ") "
<< width_ << "x" << height_ << std::endl;
}
void move(int x, int y) override {
x_ += x;
y_ += y;
}
private:
int x_, y_, width_, height_;
};
class CompositeGraphic : public Graphic {
public:
void draw() const override {
std::cout << "Composite graphic containing:" << std::endl;
for (const auto& graphic : graphics_) {
graphic->draw();
}
}
void move(int x, int y) override {
for (auto& graphic : graphics_) {
graphic->move(x, y);
}
}
void add(std::shared_ptr<Graphic> graphic) {
graphics_.push_back(graphic);
}
void remove(std::shared_ptr<Graphic> graphic) {
auto it = std::find(graphics_.begin(), graphics_.end(), graphic);
if (it != graphics_.end()) {
graphics_.erase(it);
}
}
private:
std::vector<std::shared_ptr<Graphic>> graphics_;
};
// Usage
int main() {
// File system example
auto root = std::make_shared<Directory>("root");
auto home = std::make_shared<Directory>("home");
auto documents = std::make_shared<Directory>("documents");
auto file1 = std::make_shared<File>("resume.pdf", 150);
auto file2 = std::make_shared<File>("photo.jpg", 2500);
auto file3 = std::make_shared<File>("notes.txt", 45);
documents->add(file1);
documents->add(file3);
home->add(documents);
home->add(file2);
auto usr = std::make_shared<Directory>("usr");
auto bin = std::make_shared<Directory>("bin");
auto lib = std::make_shared<Directory>("lib");
auto file4 = std::make_shared<File>("bash", 1200);
auto file5 = std::make_shared<File>("python", 4500);
bin->add(file4);
bin->add(file5);
usr->add(bin);
usr->add(lib);
root->add(home);
root->add(usr);
root->display();
std::cout << "\n---\n\n";
// Graphics example
auto allGraphics = std::make_shared<CompositeGraphic>();
auto circle1 = std::make_shared<Circle>(10, 10, 5);
auto circle2 = std::make_shared<Circle>(20, 20, 10);
auto rect1 = std::make_shared<Rectangle>(5, 5, 15, 20);
auto group1 = std::make_shared<CompositeGraphic>();
group1->add(circle1);
group1->add(circle2);
allGraphics->add(group1);
allGraphics->add(rect1);
std::cout << "Drawing all graphics:" << std::endl;
allGraphics->draw();
std::cout << "\nMoving all graphics by (5, 5):" << std::endl;
allGraphics->move(5, 5);
allGraphics->draw();
return 0;
}
Implementation in Python:
from abc import ABC, abstractmethod
from typing import List
# Component
class FileSystemComponent(ABC):
@abstractmethod
def get_name(self) -> str:
pass
@abstractmethod
def get_size(self) -> int:
pass
@abstractmethod
def display(self, depth: int = 0) -> None:
pass
def add(self, component: 'FileSystemComponent') -> None:
raise NotImplementedError("Operation not supported")
def remove(self, component: 'FileSystemComponent') -> None:
raise NotImplementedError("Operation not supported")
# Leaf
class File(FileSystemComponent):
def __init__(self, name: str, size: int):
self._name = name
self._size = size
def get_name(self) -> str:
return self._name
def get_size(self) -> int:
return self._size
def display(self, depth: int = 0) -> None:
indent = " " * depth
print(f"{indent}📄 {self._name} ({self._size} KB)")
# Composite
class Directory(FileSystemComponent):
def __init__(self, name: str):
self._name = name
self._children: List[FileSystemComponent] = []
def get_name(self) -> str:
return self._name
def get_size(self) -> int:
return sum(child.get_size() for child in self._children)
def display(self, depth: int = 0) -> None:
indent = " " * depth
print(f"{indent}📁 {self._name} ({self.get_size()} KB total)")
for child in self._children:
child.display(depth + 1)
def add(self, component: FileSystemComponent) -> None:
self._children.append(component)
def remove(self, component: FileSystemComponent) -> None:
self._children.remove(component)
# Graphics example
class Graphic(ABC):
@abstractmethod
def draw(self) -> None:
pass
@abstractmethod
def move(self, x: int, y: int) -> None:
pass
class Circle(Graphic):
def __init__(self, x: int, y: int, radius: int):
self.x = x
self.y = y
self.radius = radius
def draw(self) -> None:
print(f"Circle at ({self.x},{self.y}) with radius {self.radius}")
def move(self, x: int, y: int) -> None:
self.x += x
self.y += y
class CompositeGraphic(Graphic):
def __init__(self):
self.graphics: List[Graphic] = []
def draw(self) -> None:
print("Composite graphic containing:")
for graphic in self.graphics:
graphic.draw()
def move(self, x: int, y: int) -> None:
for graphic in self.graphics:
graphic.move(x, y)
def add(self, graphic: Graphic) -> None:
self.graphics.append(graphic)
# Usage
if __name__ == "__main__":
root = Directory("root")
home = Directory("home")
documents = Directory("documents")
file1 = File("resume.pdf", 150)
file2 = File("photo.jpg", 2500)
file3 = File("notes.txt", 45)
documents.add(file1)
documents.add(file3)
home.add(documents)
home.add(file2)
root.add(home)
root.display()
print("\n---\n")
# Graphics
all_graphics = CompositeGraphic()
circle1 = Circle(10, 10, 5)
circle2 = Circle(20, 20, 10)
group1 = CompositeGraphic()
group1.add(circle1)
group1.add(circle2)
all_graphics.add(group1)
print("Drawing all graphics:")
all_graphics.draw()
print("\nMoving all by (5, 5):")
all_graphics.move(5, 5)
all_graphics.draw()
Advantages:
- Simplifies client code (treats individual and composite objects uniformly)
- Makes it easier to add new component types
- Supports recursive composition naturally
- Open/Closed Principle: can introduce new elements without breaking existing code
Disadvantages:
- Can make the design overly general
- Difficult to restrict which component types a composite may contain
- Weak type safety: the common interface cannot enforce leaf-only or composite-only operations
Related Patterns:
- Iterator: Often used to traverse composites
- Visitor: Can apply operations across composite structure
- Decorator: Often used together with Composite
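The Iterator relationship noted above can be sketched with a small Python generator that walks a composite tree depth-first. The `iter_tree` helper and the minimal `Leaf`/`Composite` classes are assumptions for illustration, not part of the implementations above:

```python
class Leaf:
    def __init__(self, name):
        self.name = name

class Composite:
    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)

def iter_tree(node):
    # Depth-first traversal; yields every node exactly once.
    # getattr with a default lets leaves (no 'children') and
    # composites share one traversal, in the spirit of the pattern.
    yield node
    for child in getattr(node, "children", []):
        yield from iter_tree(child)

tree = Composite("root", [Leaf("a"), Composite("sub", [Leaf("b")])])
names = [n.name for n in iter_tree(tree)]
print(names)  # ['root', 'a', 'sub', 'b']
```

An external iterator like this keeps traversal logic out of the component classes, so operations such as total-size or search can be written once against the generator.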
Decorator Pattern
Intent: Attach additional responsibilities to an object dynamically. Decorators provide a flexible alternative to subclassing for extending functionality.
Problem: You need to add responsibilities to individual objects without affecting other objects or using subclassing (which is static and affects all instances).
Solution: Create decorator classes that wrap the original object. Each decorator implements the same interface as the wrapped object and adds its own behavior before/after delegating to the wrapped object.
When to Use:
- You need to add responsibilities to objects dynamically and transparently
- Extension by subclassing is impractical (class is final, or leads to many subclasses)
- Responsibilities can be withdrawn
- You want to add features to objects without changing their interface
Real-World Examples:
- Coffee shop beverages with add-ons (milk, sugar, whipped cream)
- Text formatting (bold, italic, underline combinations)
- GUI components with borders, scrollbars
- Stream processing (buffered, compressed, encrypted)
- Middleware in web frameworks
Implementation in C++:
#include <iostream>
#include <memory>
#include <string>
// Component interface
class Coffee {
public:
virtual ~Coffee() = default;
virtual std::string getDescription() const = 0;
virtual double getCost() const = 0;
};
// Concrete Component
class SimpleCoffee : public Coffee {
public:
std::string getDescription() const override {
return "Simple Coffee";
}
double getCost() const override {
return 2.0;
}
};
// Base Decorator
class CoffeeDecorator : public Coffee {
public:
CoffeeDecorator(std::unique_ptr<Coffee> coffee)
: coffee_(std::move(coffee)) {}
protected:
std::unique_ptr<Coffee> coffee_;
};
// Concrete Decorators
class MilkDecorator : public CoffeeDecorator {
public:
using CoffeeDecorator::CoffeeDecorator;
std::string getDescription() const override {
return coffee_->getDescription() + ", Milk";
}
double getCost() const override {
return coffee_->getCost() + 0.5;
}
};
class SugarDecorator : public CoffeeDecorator {
public:
using CoffeeDecorator::CoffeeDecorator;
std::string getDescription() const override {
return coffee_->getDescription() + ", Sugar";
}
double getCost() const override {
return coffee_->getCost() + 0.2;
}
};
class WhippedCreamDecorator : public CoffeeDecorator {
public:
using CoffeeDecorator::CoffeeDecorator;
std::string getDescription() const override {
return coffee_->getDescription() + ", Whipped Cream";
}
double getCost() const override {
return coffee_->getCost() + 0.7;
}
};
class CaramelDecorator : public CoffeeDecorator {
public:
using CoffeeDecorator::CoffeeDecorator;
std::string getDescription() const override {
return coffee_->getDescription() + ", Caramel";
}
double getCost() const override {
return coffee_->getCost() + 0.6;
}
};
// Another example: Text formatting
class Text {
public:
virtual ~Text() = default;
virtual std::string render() const = 0;
};
class PlainText : public Text {
public:
PlainText(const std::string& content) : content_(content) {}
std::string render() const override {
return content_;
}
private:
std::string content_;
};
class TextDecorator : public Text {
public:
TextDecorator(std::unique_ptr<Text> text)
: text_(std::move(text)) {}
protected:
std::unique_ptr<Text> text_;
};
class BoldDecorator : public TextDecorator {
public:
using TextDecorator::TextDecorator;
std::string render() const override {
return "<b>" + text_->render() + "</b>";
}
};
class ItalicDecorator : public TextDecorator {
public:
using TextDecorator::TextDecorator;
std::string render() const override {
return "<i>" + text_->render() + "</i>";
}
};
class UnderlineDecorator : public TextDecorator {
public:
using TextDecorator::TextDecorator;
std::string render() const override {
return "<u>" + text_->render() + "</u>";
}
};
// Data stream example
class DataSource {
public:
virtual ~DataSource() = default;
virtual void writeData(const std::string& data) = 0;
virtual std::string readData() = 0;
};
class FileDataSource : public DataSource {
public:
FileDataSource(const std::string& filename)
: filename_(filename) {}
void writeData(const std::string& data) override {
std::cout << "Writing to file '" << filename_ << "': " << data << std::endl;
data_ = data;
}
std::string readData() override {
std::cout << "Reading from file '" << filename_ << "'" << std::endl;
return data_;
}
private:
std::string filename_;
std::string data_;
};
class DataSourceDecorator : public DataSource {
public:
DataSourceDecorator(std::unique_ptr<DataSource> source)
: wrappee_(std::move(source)) {}
protected:
std::unique_ptr<DataSource> wrappee_;
};
class EncryptionDecorator : public DataSourceDecorator {
public:
using DataSourceDecorator::DataSourceDecorator;
void writeData(const std::string& data) override {
std::string encrypted = encrypt(data);
wrappee_->writeData(encrypted);
}
std::string readData() override {
std::string data = wrappee_->readData();
return decrypt(data);
}
private:
std::string encrypt(const std::string& data) {
std::cout << "Encrypting data..." << std::endl;
return "[ENCRYPTED]" + data + "[/ENCRYPTED]";
}
std::string decrypt(const std::string& data) {
std::cout << "Decrypting data..." << std::endl;
// Simple decryption simulation
if (data.find("[ENCRYPTED]") == 0) {
return data.substr(11, data.length() - 23);
}
return data;
}
};
class CompressionDecorator : public DataSourceDecorator {
public:
using DataSourceDecorator::DataSourceDecorator;
void writeData(const std::string& data) override {
std::string compressed = compress(data);
wrappee_->writeData(compressed);
}
std::string readData() override {
std::string data = wrappee_->readData();
return decompress(data);
}
private:
std::string compress(const std::string& data) {
std::cout << "Compressing data..." << std::endl;
return "[COMPRESSED]" + data + "[/COMPRESSED]";
}
std::string decompress(const std::string& data) {
std::cout << "Decompressing data..." << std::endl;
if (data.find("[COMPRESSED]") == 0) {
// "[COMPRESSED]" is 12 chars, "[/COMPRESSED]" is 13: strip 25 total
return data.substr(12, data.length() - 25);
}
return data;
}
};
// Usage
int main() {
// Coffee example
std::unique_ptr<Coffee> myCoffee = std::make_unique<SimpleCoffee>();
std::cout << myCoffee->getDescription() << " - $" << myCoffee->getCost() << std::endl;
myCoffee = std::make_unique<MilkDecorator>(std::move(myCoffee));
std::cout << myCoffee->getDescription() << " - $" << myCoffee->getCost() << std::endl;
myCoffee = std::make_unique<SugarDecorator>(std::move(myCoffee));
std::cout << myCoffee->getDescription() << " - $" << myCoffee->getCost() << std::endl;
myCoffee = std::make_unique<WhippedCreamDecorator>(std::move(myCoffee));
std::cout << myCoffee->getDescription() << " - $" << myCoffee->getCost() << std::endl;
std::cout << "\n---\n\n";
// Text formatting example
std::unique_ptr<Text> text = std::make_unique<PlainText>("Hello World");
std::cout << text->render() << std::endl;
text = std::make_unique<BoldDecorator>(std::move(text));
std::cout << text->render() << std::endl;
text = std::make_unique<ItalicDecorator>(std::move(text));
std::cout << text->render() << std::endl;
text = std::make_unique<UnderlineDecorator>(std::move(text));
std::cout << text->render() << std::endl;
std::cout << "\n---\n\n";
// Data stream example - combining compression and encryption
std::unique_ptr<DataSource> source = std::make_unique<FileDataSource>("data.txt");
source = std::make_unique<CompressionDecorator>(std::move(source));
source = std::make_unique<EncryptionDecorator>(std::move(source));
source->writeData("Sensitive information");
std::cout << "\nReading back:" << std::endl;
std::string data = source->readData();
std::cout << "Final data: " << data << std::endl;
return 0;
}
Implementation in Python:
from abc import ABC, abstractmethod
# Component
class Coffee(ABC):
@abstractmethod
def get_description(self) -> str:
pass
@abstractmethod
def get_cost(self) -> float:
pass
# Concrete Component
class SimpleCoffee(Coffee):
def get_description(self) -> str:
return "Simple Coffee"
def get_cost(self) -> float:
return 2.0
# Base Decorator
class CoffeeDecorator(Coffee):
def __init__(self, coffee: Coffee):
self._coffee = coffee
# Concrete Decorators
class MilkDecorator(CoffeeDecorator):
def get_description(self) -> str:
return self._coffee.get_description() + ", Milk"
def get_cost(self) -> float:
return self._coffee.get_cost() + 0.5
class SugarDecorator(CoffeeDecorator):
def get_description(self) -> str:
return self._coffee.get_description() + ", Sugar"
def get_cost(self) -> float:
return self._coffee.get_cost() + 0.2
class WhippedCreamDecorator(CoffeeDecorator):
def get_description(self) -> str:
return self._coffee.get_description() + ", Whipped Cream"
def get_cost(self) -> float:
return self._coffee.get_cost() + 0.7
# Text formatting
class Text(ABC):
@abstractmethod
def render(self) -> str:
pass
class PlainText(Text):
def __init__(self, content: str):
self.content = content
def render(self) -> str:
return self.content
class TextDecorator(Text):
def __init__(self, text: Text):
self._text = text
class BoldDecorator(TextDecorator):
def render(self) -> str:
return f"<b>{self._text.render()}</b>"
class ItalicDecorator(TextDecorator):
def render(self) -> str:
return f"<i>{self._text.render()}</i>"
# Usage
if __name__ == "__main__":
coffee = SimpleCoffee()
print(f"{coffee.get_description()} - ${coffee.get_cost()}")
coffee = MilkDecorator(coffee)
print(f"{coffee.get_description()} - ${coffee.get_cost()}")
coffee = SugarDecorator(coffee)
print(f"{coffee.get_description()} - ${coffee.get_cost()}")
coffee = WhippedCreamDecorator(coffee)
print(f"{coffee.get_description()} - ${coffee.get_cost()}")
print("\n---\n")
text = PlainText("Hello World")
print(text.render())
text = BoldDecorator(text)
print(text.render())
text = ItalicDecorator(text)
print(text.render())
Advantages:
- More flexible than static inheritance
- Responsibilities can be added/removed at runtime
- Avoids feature-laden classes high up in the hierarchy
- Single Responsibility Principle: each decorator handles a single concern
- Open/Closed Principle: extend behavior without modifying existing code
Disadvantages:
- Can result in many small objects (complexity)
- Decorators and their component aren't identical
- Hard to remove a specific decorator from the wrapper stack
Related Patterns:
- Adapter: Changes interface; Decorator enhances responsibilities
- Composite: A Decorator can be viewed as a degenerate Composite with only one component
- Strategy: Decorator changes object's skin; Strategy changes object's guts
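The stream-processing example listed under Real-World Examples exists in Python's standard library: `gzip.GzipFile` wraps any file-like object and adds compression behind the same read/write interface, exactly like the `CompressionDecorator` above. A minimal round-trip:

```python
import gzip
import io

# Each layer wraps the one below and adds behavior behind the same file API
raw = io.BytesIO()                                 # in-memory byte "file"
with gzip.GzipFile(fileobj=raw, mode="wb") as gz:  # decorator: adds compression
    gz.write(b"Sensitive information")

raw.seek(0)
with gzip.GzipFile(fileobj=raw, mode="rb") as gz:  # unwrap on read
    print(gz.read().decode())  # Sensitive information
```

Because every layer exposes the file interface, further wrappers (e.g. an encrypting stream) could be stacked in the same way without either layer knowing about the other.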
Facade Pattern
Intent: Provide a unified interface to a set of interfaces in a subsystem. Facade defines a higher-level interface that makes the subsystem easier to use.
Problem: A complex subsystem with many interdependent classes requires substantial knowledge to use effectively. Clients shouldn't need to know about subsystem implementation details.
Solution: Create a facade class that provides simple methods for client interactions with the subsystem, delegating to appropriate subsystem objects.
When to Use:
- You want to provide a simple interface to a complex subsystem
- There are many dependencies between clients and implementation classes
- You want to layer your subsystems
- You need to decouple subsystem from clients and other subsystems
Real-World Examples:
- Computer startup (facade hides CPU, memory, hard drive interactions)
- Home theater system (one button to start movie: turn on projector, sound system, DVD player, etc.)
- Online shopping checkout (facade over payment, inventory, shipping systems)
- REST API wrapping multiple microservices
- Compiler facade over lexer, parser, code generator
Implementation in C++:
#include <iostream>
#include <memory>
#include <string>
// Subsystem classes - Complex components
class CPU {
public:
void freeze() {
std::cout << "CPU: Freezing processor" << std::endl;
}
void jump(long address) {
std::cout << "CPU: Jumping to address " << address << std::endl;
}
void execute() {
std::cout << "CPU: Executing instructions" << std::endl;
}
};
class Memory {
public:
void load(long position, const std::string& data) {
std::cout << "Memory: Loading data '" << data
<< "' at position " << position << std::endl;
}
};
class HardDrive {
public:
std::string read(long lba, int size) {
std::cout << "HardDrive: Reading " << size
<< " bytes from sector " << lba << std::endl;
return "BOOT_DATA";
}
};
// Facade
class ComputerFacade {
public:
ComputerFacade()
: cpu_(std::make_unique<CPU>()),
memory_(std::make_unique<Memory>()),
hardDrive_(std::make_unique<HardDrive>()) {}
void start() {
std::cout << "Computer starting up..." << std::endl;
cpu_->freeze();
memory_->load(0, hardDrive_->read(0, 1024));
cpu_->jump(0);
cpu_->execute();
std::cout << "Computer started successfully!" << std::endl;
}
private:
std::unique_ptr<CPU> cpu_;
std::unique_ptr<Memory> memory_;
std::unique_ptr<HardDrive> hardDrive_;
};
// Home Theater example
class Amplifier {
public:
void on() { std::cout << "Amplifier: ON" << std::endl; }
void setVolume(int level) {
std::cout << "Amplifier: Setting volume to " << level << std::endl;
}
void off() { std::cout << "Amplifier: OFF" << std::endl; }
};
class DvdPlayer {
public:
void on() { std::cout << "DVD Player: ON" << std::endl; }
void play(const std::string& movie) {
std::cout << "DVD Player: Playing '" << movie << "'" << std::endl;
}
void stop() { std::cout << "DVD Player: Stopped" << std::endl; }
void off() { std::cout << "DVD Player: OFF" << std::endl; }
};
class Projector {
public:
void on() { std::cout << "Projector: ON" << std::endl; }
void wideScreenMode() { std::cout << "Projector: Widescreen mode" << std::endl; }
void off() { std::cout << "Projector: OFF" << std::endl; }
};
class TheaterLights {
public:
void dim(int level) {
std::cout << "Theater Lights: Dimming to " << level << "%" << std::endl;
}
void on() { std::cout << "Theater Lights: ON" << std::endl; }
};
class Screen {
public:
void down() { std::cout << "Screen: Going down" << std::endl; }
void up() { std::cout << "Screen: Going up" << std::endl; }
};
// Home Theater Facade
class HomeTheaterFacade {
public:
HomeTheaterFacade(
std::shared_ptr<Amplifier> amp,
std::shared_ptr<DvdPlayer> dvd,
std::shared_ptr<Projector> projector,
std::shared_ptr<Screen> screen,
std::shared_ptr<TheaterLights> lights)
: amp_(amp), dvd_(dvd), projector_(projector),
screen_(screen), lights_(lights) {}
void watchMovie(const std::string& movie) {
std::cout << "\nGet ready to watch a movie..." << std::endl;
lights_->dim(10);
screen_->down();
projector_->on();
projector_->wideScreenMode();
amp_->on();
amp_->setVolume(5);
dvd_->on();
dvd_->play(movie);
}
void endMovie() {
std::cout << "\nShutting down movie theater..." << std::endl;
dvd_->stop();
dvd_->off();
amp_->off();
projector_->off();
screen_->up();
lights_->on();
}
private:
std::shared_ptr<Amplifier> amp_;
std::shared_ptr<DvdPlayer> dvd_;
std::shared_ptr<Projector> projector_;
std::shared_ptr<Screen> screen_;
std::shared_ptr<TheaterLights> lights_;
};
// Usage
int main() {
// Computer facade example
ComputerFacade computer;
computer.start();
std::cout << "\n---\n";
// Home theater facade example
auto amp = std::make_shared<Amplifier>();
auto dvd = std::make_shared<DvdPlayer>();
auto projector = std::make_shared<Projector>();
auto screen = std::make_shared<Screen>();
auto lights = std::make_shared<TheaterLights>();
HomeTheaterFacade homeTheater(amp, dvd, projector, screen, lights);
homeTheater.watchMovie("Inception");
homeTheater.endMovie();
return 0;
}
Implementation in Python:
# Subsystem classes
class CPU:
def freeze(self) -> None:
print("CPU: Freezing processor")
def jump(self, address: int) -> None:
print(f"CPU: Jumping to address {address}")
def execute(self) -> None:
print("CPU: Executing instructions")
class Memory:
def load(self, position: int, data: str) -> None:
print(f"Memory: Loading data '{data}' at position {position}")
class HardDrive:
def read(self, lba: int, size: int) -> str:
print(f"HardDrive: Reading {size} bytes from sector {lba}")
return "BOOT_DATA"
# Facade
class ComputerFacade:
def __init__(self):
self.cpu = CPU()
self.memory = Memory()
self.hard_drive = HardDrive()
def start(self) -> None:
print("Computer starting up...")
self.cpu.freeze()
self.memory.load(0, self.hard_drive.read(0, 1024))
self.cpu.jump(0)
self.cpu.execute()
print("Computer started successfully!")
# Home Theater classes
class Amplifier:
def on(self) -> None:
print("Amplifier: ON")
def set_volume(self, level: int) -> None:
print(f"Amplifier: Setting volume to {level}")
def off(self) -> None:
print("Amplifier: OFF")
class DvdPlayer:
def on(self) -> None:
print("DVD Player: ON")
def play(self, movie: str) -> None:
print(f"DVD Player: Playing '{movie}'")
def stop(self) -> None:
print("DVD Player: Stopped")
def off(self) -> None:
print("DVD Player: OFF")
class Projector:
def on(self) -> None:
print("Projector: ON")
def wide_screen_mode(self) -> None:
print("Projector: Widescreen mode")
def off(self) -> None:
print("Projector: OFF")
class HomeTheaterFacade:
def __init__(self, amp: Amplifier, dvd: DvdPlayer, projector: Projector):
self.amp = amp
self.dvd = dvd
self.projector = projector
def watch_movie(self, movie: str) -> None:
print("\nGet ready to watch a movie...")
self.projector.on()
self.projector.wide_screen_mode()
self.amp.on()
self.amp.set_volume(5)
self.dvd.on()
self.dvd.play(movie)
def end_movie(self) -> None:
print("\nShutting down movie theater...")
self.dvd.stop()
self.dvd.off()
self.amp.off()
self.projector.off()
# Usage
if __name__ == "__main__":
computer = ComputerFacade()
computer.start()
print("\n---\n")
amp = Amplifier()
dvd = DvdPlayer()
projector = Projector()
home_theater = HomeTheaterFacade(amp, dvd, projector)
home_theater.watch_movie("Inception")
home_theater.end_movie()
Advantages:
- Simplifies complex subsystems for clients
- Reduces coupling between subsystems and clients
- Promotes weak coupling between subsystems
- Provides a simple default view while allowing access to lower-level features
- Makes libraries easier to use and test
Disadvantages:
- Facade can become a god object coupled to all classes of an application
- May introduce additional abstraction layer
- Can limit functionality if not designed to expose all subsystem features
Related Patterns:
- Abstract Factory: Can be used with Facade to hide platform-specific classes
- Mediator: Similar to Facade but abstracts communication between subsystem objects (bidirectional vs. unidirectional)
- Singleton: Facade objects are often singletons
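The Singleton relationship noted above commonly appears in Python as a single module-level facade instance that all importers share. A minimal sketch, tied to the compiler example from Real-World Examples; the `Lexer`, `Parser`, and `CompilerFacade` names are illustrative assumptions:

```python
# A module acting as a singleton facade over compiler subsystems
class Lexer:
    def tokenize(self, src):
        return src.split()

class Parser:
    def parse(self, tokens):
        return {"tokens": tokens}

class CompilerFacade:
    def __init__(self):
        self._lexer = Lexer()
        self._parser = Parser()

    def compile(self, src):
        # Single simple entry point; subsystems stay hidden
        return self._parser.parse(self._lexer.tokenize(src))

# One shared instance at module level -- importers all see the same object,
# which is the usual lightweight Python substitute for a Singleton class
compiler = CompilerFacade()
print(compiler.compile("a b")["tokens"])  # ['a', 'b']
```

Clients import `compiler` and call `compile()`; the lexer and parser never leak into client code.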
Flyweight Pattern
Intent: Use sharing to support large numbers of fine-grained objects efficiently by sharing common state.
Problem: Creating a large number of similar objects consumes too much memory. Many objects share common data that doesn't need to be duplicated.
Solution: Separate intrinsic state (shared) from extrinsic state (unique). Store intrinsic state in flyweight objects that can be shared; pass extrinsic state to flyweight methods as parameters.
When to Use:
- Application uses large numbers of objects
- Storage costs are high because of the quantity of objects
- Most object state can be made extrinsic
- Many groups of objects may be replaced by relatively few shared objects
- Application doesn't depend on object identity
Real-World Examples:
- Text editors (character objects sharing font data)
- Game development (particles, trees, grass instances)
- String interning in programming languages
- Connection pools
- Thread pools
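String interning, listed above, is an everyday flyweight built into Python itself: `sys.intern` acts as a flyweight factory, returning one shared object for equal strings. A minimal sketch (not tied to the examples below):

```python
import sys

# sys.intern acts as a flyweight factory: equal strings map to one shared object
a = sys.intern("design_pattern_flyweight")
b = sys.intern("design_" + "pattern_flyweight")

print(a is b)  # interned equal strings share identity
print(a == b)  # value equality holds as always
```

Identity comparison (`is`) on interned strings is a pointer check, which is why interpreters and compilers intern identifiers.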
Implementation in C++:
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
// Flyweight - Shared character data
class CharacterStyle {
public:
CharacterStyle(const std::string& font, int size, const std::string& color)
: font_(font), size_(size), color_(color) {
std::cout << "Creating CharacterStyle: " << font << " " << size
<< "pt " << color << std::endl;
}
void display(char symbol, int row, int column) const {
std::cout << "Character '" << symbol << "' at (" << row << "," << column
<< ") - Font: " << font_ << " " << size_ << "pt " << color_
<< std::endl;
}
std::string getKey() const {
return font_ + "_" + std::to_string(size_) + "_" + color_;
}
private:
std::string font_; // Intrinsic state (shared)
int size_; // Intrinsic state (shared)
std::string color_; // Intrinsic state (shared)
};
// Flyweight Factory
class CharacterStyleFactory {
public:
std::shared_ptr<CharacterStyle> getStyle(
const std::string& font, int size, const std::string& color) {
std::string key = font + "_" + std::to_string(size) + "_" + color;
auto it = styles_.find(key);
if (it != styles_.end()) {
std::cout << "Reusing existing style: " << key << std::endl;
return it->second;
}
auto newStyle = std::make_shared<CharacterStyle>(font, size, color);
styles_[key] = newStyle;
return newStyle;
}
size_t getStyleCount() const {
return styles_.size();
}
private:
std::unordered_map<std::string, std::shared_ptr<CharacterStyle>> styles_;
};
// Context - Character with position (extrinsic state)
class Character {
public:
Character(char symbol, int row, int column,
std::shared_ptr<CharacterStyle> style)
: symbol_(symbol), row_(row), column_(column), style_(style) {}
void display() const {
style_->display(symbol_, row_, column_);
}
private:
char symbol_; // Extrinsic state (unique)
int row_, column_; // Extrinsic state (unique)
std::shared_ptr<CharacterStyle> style_; // Reference to flyweight
};
// Game example - Trees in a forest
class TreeType {
public:
TreeType(const std::string& name, const std::string& color, const std::string& texture)
: name_(name), color_(color), texture_(texture) {
std::cout << "Creating tree type: " << name << std::endl;
}
void draw(int x, int y) const {
std::cout << name_ << " tree (color: " << color_ << ", texture: " << texture_
<< ") at (" << x << "," << y << ")" << std::endl;
}
private:
std::string name_; // Intrinsic
std::string color_; // Intrinsic
std::string texture_; // Intrinsic
};
class TreeFactory {
public:
std::shared_ptr<TreeType> getTreeType(
const std::string& name, const std::string& color, const std::string& texture) {
std::string key = name + "_" + color + "_" + texture;
auto it = treeTypes_.find(key);
if (it != treeTypes_.end()) {
return it->second;
}
auto newType = std::make_shared<TreeType>(name, color, texture);
treeTypes_[key] = newType;
return newType;
}
size_t getTypeCount() const {
return treeTypes_.size();
}
private:
std::unordered_map<std::string, std::shared_ptr<TreeType>> treeTypes_;
};
class Tree {
public:
Tree(int x, int y, std::shared_ptr<TreeType> type)
: x_(x), y_(y), type_(type) {}
void draw() const {
type_->draw(x_, y_);
}
private:
int x_, y_; // Extrinsic state (unique per tree)
std::shared_ptr<TreeType> type_; // Reference to the shared flyweight
};
class Forest {
public:
void plantTree(int x, int y, const std::string& name,
const std::string& color, const std::string& texture) {
auto type = treeFactory_.getTreeType(name, color, texture);
trees_.push_back(std::make_unique<Tree>(x, y, type));
}
void draw() const {
for (const auto& tree : trees_) {
tree->draw();
}
std::cout << "\nTotal trees: " << trees_.size()
<< ", Unique tree types: " << treeFactory_.getTypeCount()
<< std::endl;
}
private:
TreeFactory treeFactory_;
std::vector<std::unique_ptr<Tree>> trees_;
};
// Usage
int main() {
// Text editor example
CharacterStyleFactory styleFactory;
auto arial12Black = styleFactory.getStyle("Arial", 12, "Black");
auto arial12Red = styleFactory.getStyle("Arial", 12, "Red");
auto arial14Black = styleFactory.getStyle("Arial", 14, "Black");
auto arial12Black2 = styleFactory.getStyle("Arial", 12, "Black"); // Reuses
std::vector<Character> document;
document.emplace_back('H', 0, 0, arial14Black);
document.emplace_back('e', 0, 1, arial12Black);
document.emplace_back('l', 0, 2, arial12Black);
document.emplace_back('l', 0, 3, arial12Red);
document.emplace_back('o', 0, 4, arial12Black);
std::cout << "\nDocument with " << document.size() << " characters:" << std::endl;
for (const auto& ch : document) {
ch.display();
}
std::cout << "\nTotal unique styles created: "
<< styleFactory.getStyleCount() << std::endl;
std::cout << "\n---\n\n";
// Forest example
Forest forest;
forest.plantTree(10, 20, "Oak", "Green", "OakTexture");
forest.plantTree(50, 30, "Pine", "DarkGreen", "PineTexture");
forest.plantTree(80, 40, "Oak", "Green", "OakTexture"); // Reuses Oak type
forest.plantTree(120, 50, "Pine", "DarkGreen", "PineTexture"); // Reuses Pine
forest.plantTree(150, 60, "Birch", "White", "BirchTexture");
forest.plantTree(200, 70, "Oak", "Green", "OakTexture"); // Reuses Oak type
std::cout << "\nDrawing forest:" << std::endl;
forest.draw();
return 0;
}
Implementation in Python:
from typing import Dict
# Flyweight
class CharacterStyle:
def __init__(self, font: str, size: int, color: str):
self.font = font
self.size = size
self.color = color
print(f"Creating CharacterStyle: {font} {size}pt {color}")
def display(self, symbol: str, row: int, column: int) -> None:
print(f"Character '{symbol}' at ({row},{column}) - "
f"Font: {self.font} {self.size}pt {self.color}")
# Flyweight Factory
class CharacterStyleFactory:
def __init__(self):
self._styles: Dict[str, CharacterStyle] = {}
def get_style(self, font: str, size: int, color: str) -> CharacterStyle:
key = f"{font}_{size}_{color}"
if key in self._styles:
print(f"Reusing existing style: {key}")
return self._styles[key]
new_style = CharacterStyle(font, size, color)
self._styles[key] = new_style
return new_style
def get_style_count(self) -> int:
return len(self._styles)
# Context
class Character:
def __init__(self, symbol: str, row: int, column: int, style: CharacterStyle):
self.symbol = symbol # Extrinsic
self.row = row # Extrinsic
self.column = column # Extrinsic
self.style = style # Reference to the shared flyweight
def display(self) -> None:
self.style.display(self.symbol, self.row, self.column)
# Tree example
class TreeType:
def __init__(self, name: str, color: str, texture: str):
self.name = name
self.color = color
self.texture = texture
print(f"Creating tree type: {name}")
def draw(self, x: int, y: int) -> None:
print(f"{self.name} tree (color: {self.color}, texture: {self.texture}) "
f"at ({x},{y})")
class TreeFactory:
def __init__(self):
self._tree_types: Dict[str, TreeType] = {}
def get_tree_type(self, name: str, color: str, texture: str) -> TreeType:
key = f"{name}_{color}_{texture}"
if key in self._tree_types:
return self._tree_types[key]
new_type = TreeType(name, color, texture)
self._tree_types[key] = new_type
return new_type
def get_type_count(self) -> int:
return len(self._tree_types)
class Tree:
def __init__(self, x: int, y: int, tree_type: TreeType):
self.x = x # Extrinsic
self.y = y # Extrinsic
self.type = tree_type # Reference to the shared flyweight
def draw(self) -> None:
self.type.draw(self.x, self.y)
class Forest:
def __init__(self):
self.tree_factory = TreeFactory()
self.trees = []
def plant_tree(self, x: int, y: int, name: str, color: str, texture: str) -> None:
tree_type = self.tree_factory.get_tree_type(name, color, texture)
self.trees.append(Tree(x, y, tree_type))
def draw(self) -> None:
for tree in self.trees:
tree.draw()
print(f"\nTotal trees: {len(self.trees)}, "
f"Unique tree types: {self.tree_factory.get_type_count()}")
# Usage
if __name__ == "__main__":
# Text editor
factory = CharacterStyleFactory()
arial_12_black = factory.get_style("Arial", 12, "Black")
arial_12_red = factory.get_style("Arial", 12, "Red")
arial_12_black_2 = factory.get_style("Arial", 12, "Black") # Reuses
document = [
Character('H', 0, 0, arial_12_black),
Character('e', 0, 1, arial_12_black),
Character('l', 0, 2, arial_12_red),
Character('l', 0, 3, arial_12_black),
Character('o', 0, 4, arial_12_black),
]
print(f"\nDocument with {len(document)} characters:")
for ch in document:
ch.display()
print(f"\nTotal unique styles: {factory.get_style_count()}")
print("\n---\n")
# Forest
forest = Forest()
forest.plant_tree(10, 20, "Oak", "Green", "OakTexture")
forest.plant_tree(50, 30, "Pine", "DarkGreen", "PineTexture")
forest.plant_tree(80, 40, "Oak", "Green", "OakTexture") # Reuses
forest.plant_tree(120, 50, "Birch", "White", "BirchTexture")
print("\nDrawing forest:")
forest.draw()
Advantages:
- Reduces memory consumption significantly
- Reduces total number of objects
- Can improve performance by reducing memory allocation overhead
- Centralizes state management for shared data
Disadvantages:
- More complex code (intrinsic vs. extrinsic state separation)
- Runtime cost of computing or passing extrinsic state to flyweight methods
- Can make design less intuitive
Related Patterns:
- Composite: Often combined with Flyweight to implement shared leaf nodes
- State and Strategy: Can be implemented as flyweights
- Singleton: Flyweight factories are often singletons
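In Python, a memoizing decorator such as `functools.lru_cache` can stand in for a hand-written flyweight factory whenever the intrinsic state is hashable. A sketch with a hypothetical `Glyph` class, separate from the implementations above:

```python
from functools import lru_cache


class Glyph:
    """Holds intrinsic state only: font, size, color."""

    def __init__(self, font: str, size: int, color: str):
        self.font, self.size, self.color = font, size, color


@lru_cache(maxsize=None)
def get_glyph(font: str, size: int, color: str) -> Glyph:
    # Cached factory: equal arguments return the same shared Glyph instance
    return Glyph(font, size, color)


g1 = get_glyph("Arial", 12, "Black")
g2 = get_glyph("Arial", 12, "Black")
print(g1 is g2)  # True: the factory reused the flyweight
```

The cache plays the role of the `styles_` map in the factory classes above; eviction policy (here unbounded) is a design choice.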
Proxy Pattern
Intent: Provide a surrogate or placeholder for another object to control access to it.
Problem: You need to control access to an object for various reasons: expensive creation, remote access, access control, logging, lazy initialization, etc.
Solution: Create a proxy object with the same interface as the real object. The proxy controls access to the real object and can perform additional operations before/after forwarding requests.
When to Use:
- Virtual Proxy: Lazy initialization of expensive objects
- Remote Proxy: Local representative for remote objects
- Protection Proxy: Access control based on permissions
- Smart Reference: Additional actions when object is accessed (reference counting, locking, lazy loading)
- Caching Proxy: Cache results of expensive operations
- Logging Proxy: Log requests before forwarding
Real-World Examples:
- Image proxies in web browsers (placeholder until loaded)
- Network proxies and VPNs
- Smart pointers in C++
- Lazy-loaded ORM entities
- Access control in security systems
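The implementations below cover virtual, protection, and caching proxies; a logging proxy (also listed under "When to Use") follows the same shape. A minimal sketch with made-up `Service`/`LoggingProxy` names:

```python
class Service:
    def compute(self, x: int) -> int:
        return x * 2


class LoggingProxy:
    """Same interface as Service; records each call before forwarding."""

    def __init__(self, service: Service):
        self._service = service
        self.log: list[str] = []

    def compute(self, x: int) -> int:
        self.log.append(f"compute({x})")
        return self._service.compute(x)


proxy = LoggingProxy(Service())
print(proxy.compute(21))  # 42, with the call recorded in proxy.log
```

Because the proxy keeps the subject's interface, the client code is unchanged whether it holds the real service or the proxy.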
Implementation in C++:
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
// Subject interface
class Image {
public:
virtual ~Image() = default;
virtual void display() = 0;
virtual std::string getName() const = 0;
};
// Real Subject - Expensive object
class RealImage : public Image {
public:
RealImage(const std::string& filename)
: filename_(filename) {
loadFromDisk();
}
void display() override {
std::cout << "Displaying " << filename_ << std::endl;
}
std::string getName() const override {
return filename_;
}
private:
void loadFromDisk() {
std::cout << "Loading " << filename_ << " from disk (expensive operation)..."
<< std::endl;
}
std::string filename_;
};
// Virtual Proxy - Lazy initialization
class ImageProxy : public Image {
public:
ImageProxy(const std::string& filename)
: filename_(filename), realImage_(nullptr) {}
void display() override {
if (!realImage_) {
std::cout << "Proxy: First access, loading real image..." << std::endl;
realImage_ = std::make_unique<RealImage>(filename_);
}
realImage_->display();
}
std::string getName() const override {
return filename_;
}
private:
std::string filename_;
std::unique_ptr<RealImage> realImage_;
};
// Protection Proxy example
class Document {
public:
virtual ~Document() = default;
virtual void displayContent() = 0;
virtual void editContent(const std::string& newContent) = 0;
};
class RealDocument : public Document {
public:
RealDocument(const std::string& content)
: content_(content) {}
void displayContent() override {
std::cout << "Document content: " << content_ << std::endl;
}
void editContent(const std::string& newContent) override {
content_ = newContent;
std::cout << "Document updated to: " << content_ << std::endl;
}
private:
std::string content_;
};
class DocumentProxy : public Document {
public:
DocumentProxy(std::unique_ptr<RealDocument> doc, const std::string& userRole)
: document_(std::move(doc)), userRole_(userRole) {}
void displayContent() override {
std::cout << "Proxy: Checking read permissions for " << userRole_ << "..." << std::endl;
document_->displayContent();
}
void editContent(const std::string& newContent) override {
if (userRole_ == "admin" || userRole_ == "editor") {
std::cout << "Proxy: " << userRole_ << " has write permission" << std::endl;
document_->editContent(newContent);
} else {
std::cout << "Proxy: Access denied! " << userRole_
<< " doesn't have write permission" << std::endl;
}
}
private:
std::unique_ptr<RealDocument> document_;
std::string userRole_;
};
// Caching Proxy example
class DataService {
public:
virtual ~DataService() = default;
virtual std::string getData(const std::string& key) = 0;
};
class RealDataService : public DataService {
public:
std::string getData(const std::string& key) override {
std::cout << "RealDataService: Fetching '" << key
<< "' from database (expensive)..." << std::endl;
return "Data for " + key;
}
};
class CachingDataServiceProxy : public DataService {
public:
CachingDataServiceProxy(std::unique_ptr<RealDataService> service)
: service_(std::move(service)) {}
std::string getData(const std::string& key) override {
auto it = cache_.find(key);
if (it != cache_.end()) {
std::cout << "CachingProxy: Returning cached data for '" << key << "'"
<< std::endl;
return it->second;
}
std::cout << "CachingProxy: Cache miss, fetching from real service..."
<< std::endl;
std::string data = service_->getData(key);
cache_[key] = data;
return data;
}
private:
std::unique_ptr<RealDataService> service_;
std::unordered_map<std::string, std::string> cache_;
};
// Usage
int main() {
// Virtual Proxy - Lazy loading
std::cout << "=== Virtual Proxy Example ===" << std::endl;
auto image1 = std::make_unique<ImageProxy>("photo1.jpg");
auto image2 = std::make_unique<ImageProxy>("photo2.jpg");
std::cout << "\nImages created (not loaded yet)\n" << std::endl;
image1->display(); // Loads and displays
image1->display(); // Just displays (already loaded)
std::cout << "\n---\n\n";
// Protection Proxy
std::cout << "=== Protection Proxy Example ===" << std::endl;
auto adminDoc = std::make_unique<DocumentProxy>(
std::make_unique<RealDocument>("Secret Document"),
"admin"
);
auto viewerDoc = std::make_unique<DocumentProxy>(
std::make_unique<RealDocument>("Public Document"),
"viewer"
);
adminDoc->displayContent();
adminDoc->editContent("Updated Secret Document");
std::cout << std::endl;
viewerDoc->displayContent();
viewerDoc->editContent("Trying to update"); // Denied
std::cout << "\n---\n\n";
// Caching Proxy
std::cout << "=== Caching Proxy Example ===" << std::endl;
auto dataService = std::make_unique<CachingDataServiceProxy>(
std::make_unique<RealDataService>()
);
std::cout << dataService->getData("user:123") << std::endl;
std::cout << std::endl;
std::cout << dataService->getData("user:123") << std::endl; // From cache
std::cout << std::endl;
std::cout << dataService->getData("user:456") << std::endl;
return 0;
}
Implementation in Python:
from abc import ABC, abstractmethod
from typing import Dict, Optional
# Virtual Proxy
class Image(ABC):
@abstractmethod
def display(self) -> None:
pass
class RealImage(Image):
def __init__(self, filename: str):
self.filename = filename
self._load_from_disk()
def _load_from_disk(self) -> None:
print(f"Loading {self.filename} from disk (expensive operation)...")
def display(self) -> None:
print(f"Displaying {self.filename}")
class ImageProxy(Image):
def __init__(self, filename: str):
self.filename = filename
self._real_image: Optional[RealImage] = None
def display(self) -> None:
if self._real_image is None:
print("Proxy: First access, loading real image...")
self._real_image = RealImage(self.filename)
self._real_image.display()
# Protection Proxy
class Document(ABC):
@abstractmethod
def display_content(self) -> None:
pass
@abstractmethod
def edit_content(self, new_content: str) -> None:
pass
class RealDocument(Document):
def __init__(self, content: str):
self.content = content
def display_content(self) -> None:
print(f"Document content: {self.content}")
def edit_content(self, new_content: str) -> None:
self.content = new_content
print(f"Document updated to: {self.content}")
class DocumentProxy(Document):
def __init__(self, document: RealDocument, user_role: str):
self.document = document
self.user_role = user_role
def display_content(self) -> None:
print(f"Proxy: Checking read permissions for {self.user_role}...")
self.document.display_content()
def edit_content(self, new_content: str) -> None:
if self.user_role in ["admin", "editor"]:
print(f"Proxy: {self.user_role} has write permission")
self.document.edit_content(new_content)
else:
print(f"Proxy: Access denied! {self.user_role} doesn't have write permission")
# Caching Proxy
class DataService(ABC):
@abstractmethod
def get_data(self, key: str) -> str:
pass
class RealDataService(DataService):
def get_data(self, key: str) -> str:
print(f"RealDataService: Fetching '{key}' from database (expensive)...")
return f"Data for {key}"
class CachingDataServiceProxy(DataService):
def __init__(self, service: RealDataService):
self.service = service
self.cache: Dict[str, str] = {}
def get_data(self, key: str) -> str:
if key in self.cache:
print(f"CachingProxy: Returning cached data for '{key}'")
return self.cache[key]
print("CachingProxy: Cache miss, fetching from real service...")
data = self.service.get_data(key)
self.cache[key] = data
return data
# Usage
if __name__ == "__main__":
# Virtual Proxy
print("=== Virtual Proxy Example ===")
image1 = ImageProxy("photo1.jpg")
image2 = ImageProxy("photo2.jpg")
print("\nImages created (not loaded yet)\n")
image1.display() # Loads and displays
image1.display() # Just displays
print("\n---\n")
# Protection Proxy
print("=== Protection Proxy Example ===")
admin_doc = DocumentProxy(RealDocument("Secret Document"), "admin")
viewer_doc = DocumentProxy(RealDocument("Public Document"), "viewer")
admin_doc.display_content()
admin_doc.edit_content("Updated Secret Document")
print()
viewer_doc.display_content()
viewer_doc.edit_content("Trying to update") # Denied
print("\n---\n")
# Caching Proxy
print("=== Caching Proxy Example ===")
data_service = CachingDataServiceProxy(RealDataService())
print(data_service.get_data("user:123"))
print()
print(data_service.get_data("user:123")) # From cache
print()
print(data_service.get_data("user:456"))
Advantages:
- Controls access to the real object
- Can add functionality transparently (logging, caching, lazy loading)
- Open/Closed Principle: introduce new proxies without changing the service
- Can manage lifecycle of service object
Disadvantages:
- Code may become more complicated (many new classes)
- Response from service might be delayed
- Additional layer of indirection
Related Patterns:
- Adapter: Provides different interface; Proxy provides same interface
- Decorator: Similar structure but different intent (enhancement vs. access control)
- Facade: Provides simplified interface; Proxy provides same interface
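In Python, a generic proxy can forward arbitrary attribute access through `__getattr__` instead of writing one wrapper method per operation. A sketch (the `Target` class is hypothetical):

```python
class Target:
    value = 42

    def greet(self) -> str:
        return "hello"


class AttrProxy:
    """Forwards any attribute not found on the proxy to the wrapped object."""

    def __init__(self, wrapped):
        self._wrapped = wrapped
        self.calls = 0

    def __getattr__(self, name):
        # Invoked only when normal lookup fails on the proxy itself
        self.calls += 1
        return getattr(self._wrapped, name)


p = AttrProxy(Target())
print(p.greet())  # "hello", forwarded to Target
print(p.value)    # 42
```

This keeps the proxy transparent for reads; intercepting writes as well would require overriding `__setattr__` too.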
Behavioral Patterns
Observer Pattern
Intent: Define a one-to-many dependency between objects so that when one object changes state, all its dependents are notified and updated automatically.
Problem: You need to maintain consistency between related objects without making them tightly coupled. When one object changes, an unknown number of other objects need to be updated.
Solution: Define Subject (publisher) and Observer (subscriber) interfaces. Subjects maintain a list of observers and notify them automatically of state changes. Observers register/unregister themselves with subjects.
When to Use:
- Change to one object requires changing others, and you don't know how many
- Object should notify other objects without knowing who they are
- Need loosely coupled event handling system
- Implementing distributed event handling (MVC, pub-sub systems)
Real-World Examples:
- Event listeners in GUI frameworks
- Model-View-Controller (MVC) architecture
- Social media notifications (followers notified of new posts)
- Stock market tickers
- RSS feeds
- Reactive programming (RxJS, ReactiveX)
Implementation in C++:
#include <iostream>
#include <vector>
#include <string>
#include <memory>
#include <algorithm>
// Observer interface
class Observer {
public:
virtual ~Observer() = default;
virtual void update(const std::string& message) = 0;
virtual std::string getName() const = 0;
};
// Subject (Observable) interface
class Subject {
public:
virtual ~Subject() = default;
virtual void attach(std::shared_ptr<Observer> observer) = 0;
virtual void detach(std::shared_ptr<Observer> observer) = 0;
virtual void notify(const std::string& message) = 0;
};
// Concrete Subject - News Agency
class NewsAgency : public Subject {
public:
void attach(std::shared_ptr<Observer> observer) override {
observers_.push_back(observer);
std::cout << "NewsAgency: Attached observer " << observer->getName() << std::endl;
}
void detach(std::shared_ptr<Observer> observer) override {
// weak_ptr has no operator==, so compare ownership instead of using std::find
auto it = std::find_if(observers_.begin(), observers_.end(),
[&observer](const std::weak_ptr<Observer>& w) {
return !w.owner_before(observer) && !observer.owner_before(w);
});
if (it != observers_.end()) {
std::cout << "NewsAgency: Detached observer " << observer->getName() << std::endl;
observers_.erase(it);
}
}
void notify(const std::string& message) override {
std::cout << "\nNewsAgency: Broadcasting news..." << std::endl;
for (auto& observer : observers_) {
if (auto obs = observer.lock()) {
obs->update(message);
}
}
}
void publishNews(const std::string& news) {
news_ = news;
notify(news_);
}
private:
std::string news_;
std::vector<std::weak_ptr<Observer>> observers_;
};
// Concrete Observers
class NewsChannel : public Observer {
public:
NewsChannel(const std::string& name) : name_(name) {}
void update(const std::string& message) override {
std::cout << "NewsChannel [" << name_ << "]: Received news - " << message << std::endl;
}
std::string getName() const override {
return name_;
}
private:
std::string name_;
};
class Newspaper : public Observer {
public:
Newspaper(const std::string& name) : name_(name) {}
void update(const std::string& message) override {
std::cout << "Newspaper [" << name_ << "]: Printing news - " << message << std::endl;
headlines_.push_back(message);
}
std::string getName() const override {
return name_;
}
void printArchive() const {
std::cout << "\n" << name_ << " Archive:" << std::endl;
for (size_t i = 0; i < headlines_.size(); ++i) {
std::cout << " " << (i + 1) << ". " << headlines_[i] << std::endl;
}
}
private:
std::string name_;
std::vector<std::string> headlines_;
};
// Weather Station example
class WeatherStation {
public:
void setMeasurements(float temperature, float humidity, float pressure) {
temperature_ = temperature;
humidity_ = humidity;
pressure_ = pressure;
measurementsChanged();
}
void attach(std::shared_ptr<Observer> observer) {
observers_.push_back(observer);
}
void detach(std::shared_ptr<Observer> observer) {
// weak_ptr has no operator==, so compare ownership instead of using std::find
auto it = std::find_if(observers_.begin(), observers_.end(),
[&observer](const std::weak_ptr<Observer>& w) {
return !w.owner_before(observer) && !observer.owner_before(w);
});
if (it != observers_.end()) {
observers_.erase(it);
}
}
private:
void measurementsChanged() {
std::string data = "Temp: " + std::to_string(temperature_) + "°C, " +
"Humidity: " + std::to_string(humidity_) + "%, " +
"Pressure: " + std::to_string(pressure_) + " hPa";
for (auto& observer : observers_) {
if (auto obs = observer.lock()) {
obs->update(data);
}
}
}
float temperature_ = 0.0f;
float humidity_ = 0.0f;
float pressure_ = 0.0f;
std::vector<std::weak_ptr<Observer>> observers_;
};
class WeatherDisplay : public Observer {
public:
WeatherDisplay(const std::string& name) : name_(name) {}
void update(const std::string& message) override {
std::cout << "Display [" << name_ << "]: " << message << std::endl;
}
std::string getName() const override {
return name_;
}
private:
std::string name_;
};
// Usage
int main() {
// News agency example
auto newsAgency = std::make_unique<NewsAgency>();
auto cnn = std::make_shared<NewsChannel>("CNN");
auto bbc = std::make_shared<NewsChannel>("BBC");
auto nyt = std::make_shared<Newspaper>("New York Times");
newsAgency->attach(cnn);
newsAgency->attach(bbc);
newsAgency->attach(nyt);
newsAgency->publishNews("Breaking: Major tech announcement!");
std::cout << "\nDetaching CNN..." << std::endl;
newsAgency->detach(cnn);
newsAgency->publishNews("Update: Market reaches new high");
nyt->printArchive();
std::cout << "\n---\n\n";
// Weather station example
WeatherStation station;
auto homeDisplay = std::make_shared<WeatherDisplay>("Home");
auto officeDisplay = std::make_shared<WeatherDisplay>("Office");
auto mobileDisplay = std::make_shared<WeatherDisplay>("Mobile");
station.attach(homeDisplay);
station.attach(officeDisplay);
station.attach(mobileDisplay);
std::cout << "Weather update 1:" << std::endl;
station.setMeasurements(25.5f, 65.0f, 1013.2f);
std::cout << "\nWeather update 2:" << std::endl;
station.setMeasurements(27.0f, 70.0f, 1012.8f);
return 0;
}
Implementation in Python:
from abc import ABC, abstractmethod
from typing import List
from weakref import WeakSet
# Observer interface
class Observer(ABC):
@abstractmethod
def update(self, message: str) -> None:
pass
@abstractmethod
def get_name(self) -> str:
pass
# Subject interface
class Subject(ABC):
@abstractmethod
def attach(self, observer: Observer) -> None:
pass
@abstractmethod
def detach(self, observer: Observer) -> None:
pass
@abstractmethod
def notify(self, message: str) -> None:
pass
# Concrete Subject
class NewsAgency(Subject):
def __init__(self):
self._observers: WeakSet[Observer] = WeakSet()
self._news: str = ""
def attach(self, observer: Observer) -> None:
self._observers.add(observer)
print(f"NewsAgency: Attached observer {observer.get_name()}")
def detach(self, observer: Observer) -> None:
self._observers.discard(observer)
print(f"NewsAgency: Detached observer {observer.get_name()}")
def notify(self, message: str) -> None:
print("\nNewsAgency: Broadcasting news...")
for observer in self._observers:
observer.update(message)
def publish_news(self, news: str) -> None:
self._news = news
self.notify(self._news)
# Concrete Observers
class NewsChannel(Observer):
def __init__(self, name: str):
self._name = name
def update(self, message: str) -> None:
print(f"NewsChannel [{self._name}]: Received news - {message}")
def get_name(self) -> str:
return self._name
class Newspaper(Observer):
def __init__(self, name: str):
self._name = name
self._headlines: List[str] = []
def update(self, message: str) -> None:
print(f"Newspaper [{self._name}]: Printing news - {message}")
self._headlines.append(message)
def get_name(self) -> str:
return self._name
def print_archive(self) -> None:
print(f"\n{self._name} Archive:")
for i, headline in enumerate(self._headlines, 1):
print(f" {i}. {headline}")
# Weather Station example
class WeatherStation:
def __init__(self):
self._observers: WeakSet[Observer] = WeakSet()
self._temperature: float = 0.0
self._humidity: float = 0.0
self._pressure: float = 0.0
def attach(self, observer: Observer) -> None:
self._observers.add(observer)
def detach(self, observer: Observer) -> None:
self._observers.discard(observer)
def set_measurements(self, temperature: float, humidity: float, pressure: float) -> None:
self._temperature = temperature
self._humidity = humidity
self._pressure = pressure
self._measurements_changed()
def _measurements_changed(self) -> None:
data = f"Temp: {self._temperature}°C, Humidity: {self._humidity}%, Pressure: {self._pressure} hPa"
for observer in self._observers:
observer.update(data)
class WeatherDisplay(Observer):
def __init__(self, name: str):
self._name = name
def update(self, message: str) -> None:
print(f"Display [{self._name}]: {message}")
def get_name(self) -> str:
return self._name
# Usage
if __name__ == "__main__":
# News agency example
news_agency = NewsAgency()
cnn = NewsChannel("CNN")
bbc = NewsChannel("BBC")
nyt = Newspaper("New York Times")
news_agency.attach(cnn)
news_agency.attach(bbc)
news_agency.attach(nyt)
news_agency.publish_news("Breaking: Major tech announcement!")
print("\nDetaching CNN...")
news_agency.detach(cnn)
news_agency.publish_news("Update: Market reaches new high")
nyt.print_archive()
print("\n---\n")
# Weather station
station = WeatherStation()
home_display = WeatherDisplay("Home")
office_display = WeatherDisplay("Office")
station.attach(home_display)
station.attach(office_display)
print("Weather update 1:")
station.set_measurements(25.5, 65.0, 1013.2)
print("\nWeather update 2:")
station.set_measurements(27.0, 70.0, 1012.8)
Advantages:
- Loose coupling between subject and observers
- Open/Closed Principle: add new observers without modifying subject
- Establishes relationships at runtime
- Supports broadcast communication
Disadvantages:
- Observers are notified in no guaranteed order
- Memory leaks if observers aren't properly detached
- Can cause unexpected updates if dependencies are complex
- Performance issues with many observers
Related Patterns:
- Mediator: Both promote loose coupling; Mediator uses centralized communication, Observer uses distributed
- Singleton: Subject often implemented as singleton
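In Python, the same one-to-many notification is often implemented with plain callables rather than Observer classes — a lightweight sketch with a hypothetical `Signal` subject:

```python
class Signal:
    """Minimal subject: subscribers are plain callables."""

    def __init__(self):
        self._subscribers = []

    def connect(self, callback) -> None:
        self._subscribers.append(callback)

    def emit(self, message: str) -> None:
        # Copy the list so a callback may safely connect/disconnect during emit
        for callback in list(self._subscribers):
            callback(message)


received = []
news = Signal()
news.connect(received.append)
news.connect(lambda m: received.append(m.upper()))
news.emit("breaking")
print(received)  # ['breaking', 'BREAKING']
```

This trades the explicit `Observer` interface for duck typing; any callable with a matching signature can subscribe.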
Strategy Pattern
Intent: Define a family of algorithms, encapsulate each one, and make them interchangeable. Strategy lets the algorithm vary independently from clients that use it.
Problem: You have multiple related classes that differ only in their behavior. You need to select an algorithm at runtime, or you have many conditional statements choosing between different variants of the same algorithm.
Solution: Extract algorithms into separate classes (strategies) with a common interface. Context class delegates work to a strategy object instead of implementing multiple versions of the algorithm.
When to Use:
- You have many related classes differing only in behavior
- You need different variants of an algorithm
- Algorithm uses data clients shouldn't know about
- Class has massive conditional statements selecting different behaviors
Real-World Examples:
- Payment processing (credit card, PayPal, cryptocurrency)
- Sorting algorithms (quicksort, mergesort, bubblesort)
- Compression algorithms (ZIP, RAR, TAR)
- Route planning (shortest, fastest, scenic)
- Validation strategies
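Because Python functions are first-class, strategies are frequently passed as plain callables rather than classes — a sketch equivalent in spirit to the class-based implementations below (function names here are illustrative):

```python
from typing import Callable, List


def bubble_sorted(data: List[int]) -> List[int]:
    out = data[:]  # work on a copy; the input is left untouched
    for i in range(len(out)):
        for j in range(len(out) - i - 1):
            if out[j] > out[j + 1]:
                out[j], out[j + 1] = out[j + 1], out[j]
    return out


def builtin_sorted(data: List[int]) -> List[int]:
    return sorted(data)  # Timsort under the hood


def sort_with(strategy: Callable[[List[int]], List[int]], data: List[int]) -> List[int]:
    # Context: delegates to whichever strategy callable it is given
    return strategy(data)


data = [64, 34, 25, 12, 22, 11, 90]
print(sort_with(bubble_sorted, data))   # [11, 12, 22, 25, 34, 64, 90]
print(sort_with(builtin_sorted, data))  # [11, 12, 22, 25, 34, 64, 90]
```

The callable-based form avoids a class per algorithm; the class-based form (below) is preferable when strategies carry state or configuration.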
Implementation in C++:
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include <algorithm>
// Strategy interface
class SortStrategy {
public:
virtual ~SortStrategy() = default;
virtual void sort(std::vector<int>& data) = 0;
virtual std::string getName() const = 0;
};
// Concrete Strategies
class BubbleSort : public SortStrategy {
public:
void sort(std::vector<int>& data) override {
std::cout << "Sorting using Bubble Sort" << std::endl;
for (size_t i = 0; i < data.size(); ++i) {
for (size_t j = 0; j < data.size() - i - 1; ++j) {
if (data[j] > data[j + 1]) {
std::swap(data[j], data[j + 1]);
}
}
}
}
std::string getName() const override {
return "Bubble Sort";
}
};
class QuickSort : public SortStrategy {
public:
void sort(std::vector<int>& data) override {
std::cout << "Sorting using Quick Sort" << std::endl;
// std::sort (typically introsort) stands in for a hand-rolled quicksort here
std::sort(data.begin(), data.end());
}
std::string getName() const override {
return "Quick Sort";
}
};
class MergeSort : public SortStrategy {
public:
void sort(std::vector<int>& data) override {
std::cout << "Sorting using Merge Sort" << std::endl;
// std::stable_sort stands in for a hand-rolled merge sort here
std::stable_sort(data.begin(), data.end());
}
std::string getName() const override {
return "Merge Sort";
}
};
// Context
class DataSorter {
public:
void setStrategy(std::unique_ptr<SortStrategy> strategy) {
strategy_ = std::move(strategy);
}
void sort(std::vector<int>& data) {
if (strategy_) {
strategy_->sort(data);
} else {
std::cout << "No sorting strategy set!" << std::endl;
}
}
private:
std::unique_ptr<SortStrategy> strategy_;
};
// Payment example
class PaymentStrategy {
public:
virtual ~PaymentStrategy() = default;
virtual void pay(double amount) = 0;
};
class CreditCardPayment : public PaymentStrategy {
public:
CreditCardPayment(const std::string& number, const std::string& cvv)
: cardNumber_(number), cvv_(cvv) {}
void pay(double amount) override {
std::cout << "Paid $" << amount << " using Credit Card ending in "
<< cardNumber_.substr(cardNumber_.length() - 4) << std::endl;
}
private:
std::string cardNumber_;
std::string cvv_;
};
class PayPalPayment : public PaymentStrategy {
public:
PayPalPayment(const std::string& email) : email_(email) {}
void pay(double amount) override {
std::cout << "Paid $" << amount << " using PayPal account " << email_ << std::endl;
}
private:
std::string email_;
};
class ShoppingCart {
public:
void setPaymentStrategy(std::unique_ptr<PaymentStrategy> strategy) {
paymentStrategy_ = std::move(strategy);
}
void checkout(double amount) {
if (paymentStrategy_) {
paymentStrategy_->pay(amount);
}
}
private:
std::unique_ptr<PaymentStrategy> paymentStrategy_;
};
// Usage
int main() {
// Sorting example
std::vector<int> data = {64, 34, 25, 12, 22, 11, 90};
DataSorter sorter;
sorter.setStrategy(std::make_unique<BubbleSort>());
auto data1 = data;
sorter.sort(data1);
sorter.setStrategy(std::make_unique<QuickSort>());
auto data2 = data;
sorter.sort(data2);
std::cout << "\n---\n\n";
// Payment example
ShoppingCart cart;
cart.setPaymentStrategy(std::make_unique<CreditCardPayment>("1234567890123456", "123"));
cart.checkout(100.0);
cart.setPaymentStrategy(std::make_unique<PayPalPayment>("user@example.com"));
cart.checkout(50.0);
return 0;
}
Implementation in Python:
from abc import ABC, abstractmethod
from typing import List
# Strategy interface
class SortStrategy(ABC):
@abstractmethod
def sort(self, data: List[int]) -> None:
pass
@abstractmethod
def get_name(self) -> str:
pass
# Concrete Strategies
class BubbleSort(SortStrategy):
def sort(self, data: List[int]) -> None:
print("Sorting using Bubble Sort")
n = len(data)
for i in range(n):
for j in range(0, n - i - 1):
if data[j] > data[j + 1]:
data[j], data[j + 1] = data[j + 1], data[j]
def get_name(self) -> str:
return "Bubble Sort"
class QuickSort(SortStrategy):
def sort(self, data: List[int]) -> None:
print("Sorting using Quick Sort")
data.sort()
def get_name(self) -> str:
return "Quick Sort"
# Context
class DataSorter:
def __init__(self, strategy: SortStrategy = None):
self._strategy = strategy
def set_strategy(self, strategy: SortStrategy) -> None:
self._strategy = strategy
def sort(self, data: List[int]) -> None:
if self._strategy:
self._strategy.sort(data)
else:
print("No sorting strategy set!")
# Payment example
class PaymentStrategy(ABC):
@abstractmethod
def pay(self, amount: float) -> None:
pass
class CreditCardPayment(PaymentStrategy):
def __init__(self, card_number: str, cvv: str):
self.card_number = card_number
self.cvv = cvv
def pay(self, amount: float) -> None:
print(f"Paid ${amount} using Credit Card ending in {self.card_number[-4:]}")
class PayPalPayment(PaymentStrategy):
def __init__(self, email: str):
self.email = email
def pay(self, amount: float) -> None:
print(f"Paid ${amount} using PayPal account {self.email}")
class ShoppingCart:
def __init__(self):
self._payment_strategy: PaymentStrategy = None
def set_payment_strategy(self, strategy: PaymentStrategy) -> None:
self._payment_strategy = strategy
def checkout(self, amount: float) -> None:
if self._payment_strategy:
self._payment_strategy.pay(amount)
# Usage
if __name__ == "__main__":
# Sorting
data = [64, 34, 25, 12, 22, 11, 90]
sorter = DataSorter()
sorter.set_strategy(BubbleSort())
data1 = data.copy()
sorter.sort(data1)
sorter.set_strategy(QuickSort())
data2 = data.copy()
sorter.sort(data2)
print("\n---\n")
# Payment
cart = ShoppingCart()
cart.set_payment_strategy(CreditCardPayment("1234567890123456", "123"))
cart.checkout(100.0)
cart.set_payment_strategy(PayPalPayment("user@example.com"))
cart.checkout(50.0)
Advantages:
- Families of related algorithms can be reused
- Open/Closed Principle: introduce new strategies without changing context
- Runtime algorithm switching
- Isolates algorithm implementation from code that uses it
- Replaces bulky conditional statements with polymorphic dispatch
Disadvantages:
- Clients must be aware of different strategies
- Increases number of objects
- All strategies must expose same interface (even if some don't use all parameters)
Related Patterns:
- State: Both encapsulate behavior; Strategy focuses on algorithm, State on object state
- Template Method: Uses inheritance; Strategy uses composition
- Factory Method: Often used to create appropriate strategy
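As the conclusion of this guide notes, Strategy is trivial in languages with first-class functions: a strategy can be any callable, with no class hierarchy at all. A minimal Python sketch (the names `Sorter`, `ascending`, and `descending` are illustrative, not part of the implementations above):

```python
# Strategies as plain callables -- no interface class needed.
def ascending(data):
    return sorted(data)

def descending(data):
    return sorted(data, reverse=True)

class Sorter:
    """Context that delegates to whatever callable it currently holds."""
    def __init__(self, strategy=ascending):
        self.strategy = strategy

    def sort(self, data):
        return self.strategy(data)

sorter = Sorter()
print(sorter.sort([3, 1, 2]))    # [1, 2, 3]
sorter.strategy = descending     # swap the algorithm at runtime
print(sorter.sort([3, 1, 2]))    # [3, 2, 1]
```

Because functions are objects, the "set strategy" step is just an attribute assignment; lambdas and `functools.partial` work equally well as strategies.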
Command Pattern
Intent: Encapsulate a request as an object, thereby letting you parameterize clients with different requests, queue or log requests, and support undoable operations.
Problem: You need to issue requests to objects without knowing anything about the operation being requested or the receiver of the request. You want to support undo/redo, queuing, or logging of operations.
Solution: Create command objects that encapsulate all information needed to perform an action or trigger an event. Commands have an execute() method and optionally an undo() method.
When to Use:
- Parameterize objects by an action to perform
- Queue operations, schedule their execution, or execute them remotely
- Support undo/redo functionality
- Structure system around high-level operations built on primitive operations
- Support logging changes for crash recovery
Real-World Examples:
- GUI buttons and menu items
- Macro recording in applications
- Transaction-based systems
- Task scheduling systems
- Undo/redo in text editors
- Remote control systems
Implementation in C++:
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include <stack>
// Receiver
class Light {
public:
void on() {
isOn_ = true;
std::cout << "Light is ON" << std::endl;
}
void off() {
isOn_ = false;
std::cout << "Light is OFF" << std::endl;
}
bool isOn() const { return isOn_; }
private:
bool isOn_ = false;
};
// Command interface
class Command {
public:
virtual ~Command() = default;
virtual void execute() = 0;
virtual void undo() = 0;
};
// Concrete Commands
class LightOnCommand : public Command {
public:
LightOnCommand(std::shared_ptr<Light> light) : light_(light) {}
void execute() override {
light_->on();
}
void undo() override {
light_->off();
}
private:
std::shared_ptr<Light> light_;
};
class LightOffCommand : public Command {
public:
LightOffCommand(std::shared_ptr<Light> light) : light_(light) {}
void execute() override {
light_->off();
}
void undo() override {
light_->on();
}
private:
std::shared_ptr<Light> light_;
};
// Text editor example
class TextEditor {
public:
void insertText(const std::string& text) {
content_ += text;
std::cout << "Inserted: " << text << std::endl;
}
void deleteText(size_t length) {
if (length <= content_.length()) {
deletedText_ = content_.substr(content_.length() - length);
content_ = content_.substr(0, content_.length() - length);
std::cout << "Deleted: " << deletedText_ << std::endl;
}
}
std::string getDeletedText() const { return deletedText_; }
std::string getContent() const { return content_; }
void print() const {
std::cout << "Content: \"" << content_ << "\"" << std::endl;
}
private:
std::string content_;
std::string deletedText_;
};
class InsertCommand : public Command {
public:
InsertCommand(std::shared_ptr<TextEditor> editor, const std::string& text)
: editor_(editor), text_(text) {}
void execute() override {
editor_->insertText(text_);
}
void undo() override {
editor_->deleteText(text_.length());
}
private:
std::shared_ptr<TextEditor> editor_;
std::string text_;
};
class DeleteCommand : public Command {
public:
DeleteCommand(std::shared_ptr<TextEditor> editor, size_t length)
: editor_(editor), length_(length) {}
void execute() override {
editor_->deleteText(length_);
deletedText_ = editor_->getDeletedText();
}
void undo() override {
editor_->insertText(deletedText_);
}
private:
std::shared_ptr<TextEditor> editor_;
size_t length_;
std::string deletedText_;
};
// Invoker
class RemoteControl {
public:
void setCommand(std::shared_ptr<Command> command) {
command_ = command;
}
void pressButton() {
if (command_) {
command_->execute();
history_.push(command_);
}
}
void pressUndo() {
if (!history_.empty()) {
auto command = history_.top();
command->undo();
history_.pop();
}
}
private:
std::shared_ptr<Command> command_;
std::stack<std::shared_ptr<Command>> history_;
};
// Usage
int main() {
// Light control example
auto livingRoomLight = std::make_shared<Light>();
auto lightOn = std::make_shared<LightOnCommand>(livingRoomLight);
auto lightOff = std::make_shared<LightOffCommand>(livingRoomLight);
RemoteControl remote;
remote.setCommand(lightOn);
remote.pressButton();
remote.setCommand(lightOff);
remote.pressButton();
remote.pressUndo(); // Undo last command
std::cout << "\n---\n\n";
// Text editor example
auto editor = std::make_shared<TextEditor>();
std::stack<std::shared_ptr<Command>> commandHistory;
auto insertHello = std::make_shared<InsertCommand>(editor, "Hello ");
insertHello->execute();
commandHistory.push(insertHello);
auto insertWorld = std::make_shared<InsertCommand>(editor, "World!");
insertWorld->execute();
commandHistory.push(insertWorld);
editor->print();
// Undo last two commands
while (!commandHistory.empty()) {
commandHistory.top()->undo();
commandHistory.pop();
}
editor->print();
return 0;
}
Implementation in Python:
from abc import ABC, abstractmethod
from typing import List
# Receiver
class Light:
def __init__(self):
self._is_on = False
def on(self) -> None:
self._is_on = True
print("Light is ON")
def off(self) -> None:
self._is_on = False
print("Light is OFF")
def is_on(self) -> bool:
return self._is_on
# Command interface
class Command(ABC):
@abstractmethod
def execute(self) -> None:
pass
@abstractmethod
def undo(self) -> None:
pass
# Concrete Commands
class LightOnCommand(Command):
def __init__(self, light: Light):
self.light = light
def execute(self) -> None:
self.light.on()
def undo(self) -> None:
self.light.off()
class LightOffCommand(Command):
def __init__(self, light: Light):
self.light = light
def execute(self) -> None:
self.light.off()
def undo(self) -> None:
self.light.on()
# Text editor
class TextEditor:
def __init__(self):
self._content = ""
self._deleted_text = ""
def insert_text(self, text: str) -> None:
self._content += text
print(f"Inserted: {text}")
def delete_text(self, length: int) -> None:
if length <= len(self._content):
self._deleted_text = self._content[-length:]
self._content = self._content[:-length]
print(f"Deleted: {self._deleted_text}")
def get_deleted_text(self) -> str:
return self._deleted_text
def get_content(self) -> str:
return self._content
def print_content(self) -> None:
print(f"Content: \"{self._content}\"")
class InsertCommand(Command):
def __init__(self, editor: TextEditor, text: str):
self.editor = editor
self.text = text
def execute(self) -> None:
self.editor.insert_text(self.text)
def undo(self) -> None:
self.editor.delete_text(len(self.text))
# Invoker
class RemoteControl:
def __init__(self):
self._command: Command = None
self._history: List[Command] = []
def set_command(self, command: Command) -> None:
self._command = command
def press_button(self) -> None:
if self._command:
self._command.execute()
self._history.append(self._command)
def press_undo(self) -> None:
if self._history:
command = self._history.pop()
command.undo()
# Usage
if __name__ == "__main__":
# Light control
living_room_light = Light()
light_on = LightOnCommand(living_room_light)
light_off = LightOffCommand(living_room_light)
remote = RemoteControl()
remote.set_command(light_on)
remote.press_button()
remote.set_command(light_off)
remote.press_button()
remote.press_undo() # Undo
print("\n---\n")
# Text editor
editor = TextEditor()
command_history = []
insert_hello = InsertCommand(editor, "Hello ")
insert_hello.execute()
command_history.append(insert_hello)
insert_world = InsertCommand(editor, "World!")
insert_world.execute()
command_history.append(insert_world)
editor.print_content()
# Undo
while command_history:
command_history.pop().undo()
editor.print_content()
Advantages:
- Decouples object that invokes operation from one that knows how to perform it
- Commands are first-class objects (can be manipulated and extended)
- Can assemble commands into composite commands (macro commands)
- Easy to add new commands (Open/Closed Principle)
- Supports undo/redo
Disadvantages:
- Increases number of classes for each individual command
- Can become complex with many commands
Related Patterns:
- Memento: Can be used to keep state for undo
- Composite: Can be used to implement macro commands
- Prototype: Commands that must be copied before being placed on history list
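The Composite relationship above is worth a sketch: a macro command holds a list of sub-commands, executes them in order, and undoes them in reverse. This is a minimal illustration; `AppendCommand` is a made-up receiver, not part of the examples above:

```python
class AppendCommand:
    """Toy command: appends text to a shared list, undoes by popping."""
    def __init__(self, buf, text):
        self.buf, self.text = buf, text

    def execute(self):
        self.buf.append(self.text)

    def undo(self):
        self.buf.pop()

class MacroCommand:
    """Composite command: executes children in order, undoes in reverse."""
    def __init__(self, commands):
        self.commands = list(commands)

    def execute(self):
        for c in self.commands:
            c.execute()

    def undo(self):
        for c in reversed(self.commands):
            c.undo()

buf = []
macro = MacroCommand([AppendCommand(buf, "Hello "), AppendCommand(buf, "World!")])
macro.execute()
print("".join(buf))  # Hello World!
macro.undo()
print(buf)           # []
```

Because `MacroCommand` exposes the same `execute`/`undo` interface as a leaf command, it can itself sit inside another macro or on an undo stack.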
Conclusion
Design patterns are invaluable tools for software developers, providing standardized solutions to recurring design problems. By understanding and applying appropriate design patterns, developers can create more flexible, reusable, and maintainable codebases.
Key Takeaways:
- Choose the Right Pattern: Not every problem requires a design pattern. Use patterns when they genuinely simplify your design.
- Understand the Trade-offs: Each pattern has advantages and disadvantages. Consider the complexity vs. flexibility trade-off.
- Patterns Work Together: Many real-world applications combine multiple patterns. For example, MVC uses Observer, Strategy, and Composite patterns.
- Start Simple: Don't over-engineer. Refactor towards patterns when the need becomes clear.
- Language Matters: Some patterns are more natural in certain programming languages. For instance, the Strategy pattern is trivial in languages with first-class functions.
Common Pattern Categories:
- Creational (Singleton, Factory Method, Abstract Factory, Builder, Prototype): Object creation mechanisms
- Structural (Adapter, Bridge, Composite, Decorator, Facade, Flyweight, Proxy): Object composition and relationships
- Behavioral (Observer, Strategy, Command, and others): Communication between objects
This guide has covered the most fundamental and widely-used design patterns with comprehensive examples in both C++ and Python. Each pattern includes practical implementations, real-world use cases, and guidance on when to apply them. By mastering these patterns, you'll be better equipped to design robust, maintainable, and scalable software systems.
Linux Documentation
A comprehensive guide to Linux system administration, commands, kernel architecture, and networking.
Table of Contents
- Essential Commands - Command reference and examples
- Kernel Architecture - Linux kernel internals and development
- Kernel Development Patterns - Common patterns and best practices for kernel development
- cfg80211 & mac80211 - Wireless subsystem frameworks for WiFi drivers
- Driver Development - Linux driver model and device driver development
- Device Tree - Hardware description using Device Tree
- Cross Compilation - Building for different architectures
- Networking - Network configuration and troubleshooting
- Netfilter - Packet filtering framework
- iptables - Firewall configuration
- Traffic Control (tc) - Network traffic management
- systemd - Service management and init system
- sysctl - Kernel parameter tuning at runtime
- sysfs - Kernel/hardware information filesystem
- Netlink - Kernel-userspace communication interface
- eBPF - Extended Berkeley Packet Filter for kernel programmability
Overview
This documentation covers essential Linux topics for system administrators, developers, and power users. Each section provides practical examples, use cases, and best practices.
Getting Started
For Beginners
Start with Essential Commands to learn the fundamental Linux commands that you'll use daily.
For System Administrators
- Essential Commands - Master command-line tools
- Networking - Network configuration and diagnostics
- iptables - Firewall management
For Developers
- Kernel Architecture - Understand Linux internals
- Kernel Development Patterns - Coding patterns and best practices
- Driver Development - Linux driver model and device drivers
- Device Tree - Hardware description and parsing
- Cross Compilation - Building for embedded systems
- cfg80211 & mac80211 - Wireless driver development
- Essential Commands - Development and debugging tools
For Network Engineers
- Networking - Network stack and protocols
- cfg80211 & mac80211 - Wireless networking subsystem
- Netfilter - Packet filtering framework
- Traffic Control - QoS and traffic shaping
Key Topics
System Administration
- User and permission management
- Process management and monitoring
- System resource monitoring
- Service management with systemd
- Log management and analysis
Kernel Development
- Kernel architecture and components
- System calls and kernel modules
- Device drivers
- Kernel compilation and debugging
Networking
- Network configuration (ip, ifconfig)
- Routing and bridging
- Packet filtering (iptables, nftables)
- Traffic shaping and QoS
- Network troubleshooting
Quick Reference
Most Used Commands
# File operations
ls -lah # List files with details
cd /path/to/directory # Change directory
cp -r source dest # Copy recursively
mv source dest # Move/rename
rm -rf directory # Remove recursively
# Text processing
grep pattern file # Search for pattern
sed 's/old/new/g' file # Replace text
awk '{print $1}' file # Process columns
# System monitoring
top # Process viewer
htop # Enhanced process viewer
ps aux # List all processes
df -h # Disk usage
free -h # Memory usage
# Network
ip addr show # Show IP addresses
ss -tulpn # Show listening ports
ping host # Test connectivity
curl url # HTTP client
System Information
uname -a # Kernel version
lsb_release -a # Distribution info
hostnamectl # System hostname
uptime # System uptime
Learning Path
- Basics (1-2 weeks)
  - File system navigation
  - File manipulation
  - Text editors (vim, nano)
  - Basic shell scripting
- Intermediate (2-4 weeks)
  - Process management
  - User management
  - Permissions and ownership
  - Package management
  - System services
- Advanced (1-3 months)
  - Kernel modules
  - Network configuration
  - Firewall rules
  - Performance tuning
  - Security hardening
- Expert (3-6 months)
  - Kernel development
  - Custom modules
  - Advanced networking
  - High availability systems
  - Container orchestration
Best Practices
Security
- Always use sudo instead of root login
- Keep system and packages updated
- Use SSH keys instead of passwords
- Enable and configure firewall
- Regular security audits
- Monitor system logs
Performance
- Monitor system resources regularly
- Use appropriate file systems
- Optimize kernel parameters
- Implement proper backup strategies
- Use automation tools
Documentation
- Document custom configurations
- Keep change logs
- Use version control for configs
- Create runbooks for common tasks
Useful Resources
Books
- "The Linux Command Line" by William Shotts
- "Linux Kernel Development" by Robert Love
- "UNIX and Linux System Administration Handbook"
Contributing
When adding new documentation:
- Follow the existing structure
- Include practical examples
- Add use cases and scenarios
- Reference related sections
- Keep examples tested and working
Version Information
- Documentation maintained for Linux Kernel 5.x and 6.x
- Examples tested on Ubuntu 20.04/22.04 and Debian 11/12
- Command syntax may vary slightly between distributions
Networking
TUN and TAP Interfaces
TUN and TAP are virtual network kernel interfaces. They are used to create network interfaces that operate at different layers of the network stack.
TUN Interface
A TUN (network TUNnel) interface is a virtual point-to-point network device that operates at the network layer (Layer 3). It is used to route IP packets. TUN interfaces are commonly used in VPN (Virtual Private Network) implementations to tunnel IP traffic over a secure connection.
Key Features of TUN Interface:
- Operates at Layer 3 (Network Layer).
- Handles IP packets.
- Used for routing and tunneling IP traffic.
- Commonly used in VPNs.
Example Use Case:
A TUN interface can be used to create a secure VPN connection between two remote networks, allowing them to communicate as if they were on the same local network.
TAP Interface
A TAP (network TAP) interface is a virtual network device that operates at the data link layer (Layer 2). It is used to handle Ethernet frames. TAP interfaces are useful for creating network bridges and for virtual machine networking.
Key Features of TAP Interface:
- Operates at Layer 2 (Data Link Layer).
- Handles Ethernet frames.
- Used for bridging and virtual machine networking.
- Can be used to create virtual switches.
Example Use Case:
A TAP interface can be used to connect a virtual machine to a virtual switch, allowing the virtual machine to communicate with other virtual machines and the host system as if they were connected to a physical Ethernet switch.
Creating TUN and TAP Interfaces
TUN and TAP interfaces can be created and managed using the ip command or the tunctl utility. Here is an example of how to create a TUN interface using the ip command:
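For example (interface names and the address below are placeholders; these commands require root):

```shell
# Create and bring up a TUN (layer 3) interface
ip tuntap add dev tun0 mode tun
ip addr add 10.0.0.1/24 dev tun0
ip link set tun0 up

# Create and bring up a TAP (layer 2) interface
ip tuntap add dev tap0 mode tap
ip link set tap0 up

# List and remove the virtual interfaces
ip tuntap show
ip tuntap del dev tun0 mode tun
ip tuntap del dev tap0 mode tap
```

Once the interface is up, a user-space program (such as a VPN daemon) opens /dev/net/tun to read and write the packets or frames that the kernel routes through it.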
Linux Kernel Architecture
A comprehensive guide to Linux kernel internals, architecture, system calls, modules, compilation, and debugging.
Table of Contents
- Kernel Overview
- Kernel Architecture
- Memory Management
- Process Management
- System Calls
- Kernel Modules
- Device Drivers
- File Systems
- Networking Stack
- Kernel Compilation
- Kernel Debugging
- Performance Tuning
Kernel Overview
The Linux kernel is a monolithic kernel that handles all system operations including process management, memory management, device drivers, and system calls.
Kernel Architecture Types
Monolithic Kernel (Linux)
- All services run in kernel space
- Better performance (no context switching)
- Single address space
- Larger kernel size
Microkernel
- Minimal kernel (IPC, memory, scheduling)
- Services run in user space
- Better stability and security
- More context switches
Hybrid Kernel
- Combination of both approaches
- Examples: Windows NT, macOS
Linux Kernel Features
- Preemptive multitasking
- Symmetric multiprocessing (SMP)
- Virtual memory management
- Loadable kernel modules
- Multiple filesystem support
- POSIX compliance
- Dynamic kernel memory allocation
- Networking stack (TCP/IP, IPv6)
- Advanced security features (SELinux, AppArmor)
- Real-time capabilities (PREEMPT_RT)
Kernel Version Numbering
# Check kernel version
uname -r
# Output: 6.5.0-15-generic
# Format: MAJOR.MINOR.PATCH-BUILD-ARCH
# 6.5.0 - kernel version
# 15 - distribution build number
# generic - kernel flavor/variant
Version Types:
- Mainline - Latest features, active development
- Stable - Production-ready, bug fixes only
- LTS (Long Term Support) - Extended maintenance (2-6 years)
- EOL (End of Life) - No longer maintained
Kernel Source Tree Structure
/usr/src/linux/
├── arch/ # Architecture-specific code (x86, ARM, etc.)
├── block/ # Block device drivers
├── crypto/ # Cryptographic API
├── Documentation/ # Kernel documentation
├── drivers/ # Device drivers
│ ├── char/ # Character devices
│ ├── block/ # Block devices
│ ├── net/ # Network devices
│ ├── gpu/ # Graphics drivers
│ └── usb/ # USB drivers
├── fs/ # File system implementations
│ ├── ext4/ # ext4 filesystem
│ ├── btrfs/ # Btrfs filesystem
│ └── nfs/ # Network file system
├── include/ # Header files
│ ├── linux/ # Linux-specific headers
│ └── uapi/ # User-space API headers
├── init/ # Kernel initialization
├── ipc/ # Inter-process communication
├── kernel/ # Core kernel code
│ ├── sched/ # Process scheduler
│ ├── time/ # Time management
│ └── irq/ # Interrupt handling
├── lib/ # Library routines
├── mm/ # Memory management
├── net/ # Networking stack
│ ├── ipv4/ # IPv4 implementation
│ ├── ipv6/ # IPv6 implementation
│ └── core/ # Core networking
├── samples/ # Sample code
├── scripts/ # Build scripts
├── security/ # Security modules (SELinux, AppArmor)
├── sound/ # Sound drivers
└── tools/ # Kernel tools and utilities
Kernel Architecture
Kernel Space vs User Space
+------------------------------------------+
| User Space (Ring 3) |
| +--------------------------------------+ |
| | User Applications | |
| | (web browsers, editors, games, etc.) | |
| +--------------------------------------+ |
| ↕ |
| +--------------------------------------+ |
| | System Libraries (glibc, etc.) | |
| +--------------------------------------+ |
+------------------------------------------+
↕
System Call Interface
↕
+------------------------------------------+
| Kernel Space (Ring 0) |
| +--------------------------------------+ |
| | System Call Interface | |
| +--------------------------------------+ |
| | Process | Memory | File System | |
| | Management | Manager | Layer | |
| +--------------------------------------+ |
| | Network | IPC | Security | |
| | Stack | Layer | Modules | |
| +--------------------------------------+ |
| | Device Drivers | |
| | (char, block, network) | |
| +--------------------------------------+ |
| | Architecture-Specific Code | |
| | (CPU, MMU, interrupts) | |
| +--------------------------------------+ |
+------------------------------------------+
↕
Hardware Layer
Key Kernel Components
1. Process Scheduler
Manages CPU time allocation among processes.
// Scheduling classes (from highest to lowest priority)
1. SCHED_DEADLINE // Deadline scheduling (real-time)
2. SCHED_FIFO // First-in-first-out (real-time)
3. SCHED_RR // Round-robin (real-time)
4. SCHED_NORMAL // Standard time-sharing (CFS)
5. SCHED_BATCH // Batch processes
6. SCHED_IDLE // Very low priority
// Completely Fair Scheduler (CFS) - default for SCHED_NORMAL
// - Uses red-black tree for O(log n) operations
// - Virtual runtime tracking
// - Fair CPU time distribution
Check and modify scheduling:
# View process scheduling info
ps -eo pid,pri,ni,comm,policy
# Change scheduling policy
chrt -f -p 99 PID # Set to FIFO with priority 99
chrt -r -p 50 PID # Set to Round-robin
chrt -o -p 0 PID # Set to normal
# Change nice value (-20 to 19)
nice -n 10 command # Run with nice value 10
renice -n 5 -p PID # Change nice value of running process
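The same nice mechanism is available programmatically. A small Python sketch (an unprivileged process can only raise its own nice value; lowering it again requires CAP_SYS_NICE):

```python
import os

# Read this process's nice value, then raise it by one.
before = os.getpriority(os.PRIO_PROCESS, 0)   # 0 means "the calling process"
target = min(before + 1, 19)                  # nice values are capped at 19
os.setpriority(os.PRIO_PROCESS, 0, target)
after = os.getpriority(os.PRIO_PROCESS, 0)
print(before, "->", after)
```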
2. Memory Manager
Handles virtual memory, paging, and memory allocation.
Virtual Memory Layout (64-bit x86):
0x00007FFFFFFFFFFF +------------------+
| User Stack | (grows down)
+------------------+
| Memory Mapped |
| Files & Libs |
+------------------+
| Heap | (grows up)
+------------------+
| BSS (uninit data)|
+------------------+
| Data (init data) |
+------------------+
0x0000000000400000 | Text (code) |
+------------------+
| Reserved |
0x0000000000000000 +------------------+
Kernel Space Layout:
0xFFFFFFFFFFFFFFFF +------------------+
| Kernel Code/Data |
+------------------+
| Direct Mapping |
| (Physical RAM) |
+------------------+
| vmalloc Area |
+------------------+
| Module Space |
0xFFFF800000000000 +------------------+
Memory zones:
ZONE_DMA - Memory for DMA (0-16MB on x86)
ZONE_DMA32 - Memory for 32-bit DMA (0-4GB)
ZONE_NORMAL - Normal memory (above 4GB on 64-bit)
ZONE_HIGHMEM - High memory (not directly mapped, 32-bit only)
ZONE_MOVABLE - Memory that can be migrated
3. Virtual File System (VFS)
Abstract layer for file system operations.
// VFS Objects
struct super_block // Mounted filesystem
struct inode // File metadata
struct dentry // Directory entry (name to inode mapping)
struct file // Open file instance
// File operations structure
struct file_operations {
ssize_t (*read) (struct file *, char __user *, size_t, loff_t *);
ssize_t (*write) (struct file *, const char __user *, size_t, loff_t *);
int (*open) (struct inode *, struct file *);
int (*release) (struct inode *, struct file *);
// ... more operations
};
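The dentry-to-inode split is observable from user space: a hard link adds a second directory entry for the same inode, so both names report the same inode number and the link count rises. A quick Python check:

```python
import os
import tempfile

# Two dentries (names) pointing at one inode.
d = tempfile.mkdtemp()
a = os.path.join(d, "a.txt")
b = os.path.join(d, "b.txt")
with open(a, "w") as f:
    f.write("data")
os.link(a, b)                      # link() system call: new dentry, same inode
sa, sb = os.stat(a), os.stat(b)    # stat() system call reads inode metadata
print(sa.st_ino == sb.st_ino)      # True: both names resolve to one inode
print(sa.st_nlink)                 # 2: the inode now has two links
```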
4. Network Stack
Implements network protocols and socket interface.
Layer Model:
Application Layer
↕
Socket Interface
↕
Transport Layer (TCP/UDP)
↕
Network Layer (IP)
↕
Link Layer (Ethernet, WiFi)
↕
Device Driver
↕
Hardware
Memory Management
Page Management
Linux uses paging for memory management:
# Check page size
getconf PAGE_SIZE
# Usually 4096 bytes (4KB)
# View memory info
cat /proc/meminfo
# MemTotal, MemFree, MemAvailable, Buffers, Cached, etc.
# Memory statistics
vmstat 1
# View paging, memory, CPU stats every second
# Detailed memory usage
cat /proc/PID/status | grep -i vm
cat /proc/PID/maps # Memory mappings
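The page size is also queryable from code on Linux, and is always a power of two:

```python
import os
import resource

page = resource.getpagesize()
print(page)                                  # typically 4096
print(os.sysconf("SC_PAGE_SIZE") == page)    # same value via sysconf -> True
print((page & (page - 1)) == 0)              # power-of-two check -> True
```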
Memory Allocation
Kernel Memory Allocation:
// Physically contiguous memory
kmalloc(size, GFP_KERNEL) // Standard allocation
kfree(ptr) // Free memory
// Virtual contiguous memory
vmalloc(size) // Large allocations
vfree(ptr)
// Page-based allocation
alloc_pages(gfp_mask, order) // 2^order pages
free_pages(addr, order)
// Flags (GFP = Get Free Pages)
GFP_KERNEL // Standard, may sleep
GFP_ATOMIC // Cannot sleep, for interrupts
GFP_USER // User space allocation
GFP_DMA // DMA-capable memory
Memory Reclamation
OOM Killer (Out-of-Memory):
# View OOM score (higher = more likely to be killed)
cat /proc/PID/oom_score
# Adjust OOM score (-1000 to 1000)
echo -500 > /proc/PID/oom_score_adj # Less likely to be killed
echo 500 > /proc/PID/oom_score_adj # More likely to be killed
# Disable OOM killer for process
echo -1000 > /proc/PID/oom_score_adj
# View OOM killer logs
dmesg | grep -i "out of memory"
journalctl -k | grep -i "oom"
Swapping:
# View swap usage
swapon --show
free -h
# Create swap file
dd if=/dev/zero of=/swapfile bs=1M count=1024
chmod 600 /swapfile
mkswap /swapfile
swapon /swapfile
# Control swappiness (0-100, default 60)
cat /proc/sys/vm/swappiness
echo 10 > /proc/sys/vm/swappiness # Less aggressive swapping
# Make permanent in /etc/sysctl.conf
vm.swappiness=10
Huge Pages
Improve performance for applications with large memory footprints:
# View huge page info
cat /proc/meminfo | grep -i huge
# Configure huge pages
echo 512 > /proc/sys/vm/nr_hugepages
# Transparent Huge Pages (THP)
cat /sys/kernel/mm/transparent_hugepage/enabled
echo always > /sys/kernel/mm/transparent_hugepage/enabled
echo madvise > /sys/kernel/mm/transparent_hugepage/enabled # Recommended
echo never > /sys/kernel/mm/transparent_hugepage/enabled
Process Management
Process Representation
// Task structure (Process Control Block)
struct task_struct {
pid_t pid; // Process ID
pid_t tgid; // Thread group ID
struct task_struct *parent; // Parent process
struct list_head children; // Child processes
struct mm_struct *mm; // Memory descriptor
struct fs_struct *fs; // Filesystem info
struct files_struct *files; // Open files
int exit_state; // Exit status
unsigned int policy; // Scheduling policy
// ... many more fields
};
Process States
TASK_RUNNING // Running or ready to run
TASK_INTERRUPTIBLE // Sleeping, can be woken by signals
TASK_UNINTERRUPTIBLE // Sleeping, cannot be interrupted
TASK_STOPPED // Stopped (e.g., by SIGSTOP)
TASK_TRACED // Being traced by debugger
EXIT_ZOMBIE // Terminated, waiting for parent
EXIT_DEAD // Final state before removal
View process states:
ps aux
# STAT column:
# R - Running
# S - Sleeping (interruptible)
# D - Sleeping (uninterruptible, usually I/O)
# T - Stopped
# Z - Zombie
# < - High priority
# N - Low priority
# + - Foreground process group
# Find stuck processes (uninterruptible sleep)
ps aux | awk '$8 ~ /D/'
Process Creation
fork() system call:
#include <unistd.h>
#include <stdio.h>
int main() {
pid_t pid = fork();
if (pid < 0) {
// Fork failed
perror("fork");
return 1;
} else if (pid == 0) {
// Child process
printf("Child: PID = %d\n", getpid());
} else {
// Parent process
printf("Parent: PID = %d, Child PID = %d\n", getpid(), pid);
}
return 0;
}
exec() system call:
#include <unistd.h>
int main() {
char *args[] = {"/bin/ls", "-l", NULL};
execv("/bin/ls", args); // Replace current process
// Only reached if exec fails
perror("exec");
return 1;
}
Process Namespaces
Provide isolation for different resources:
# Namespace types
PID # Process IDs
NET # Network stack
MNT # Mount points
IPC # Inter-process communication
UTS # Hostname and domain name
USER # User and group IDs
CGROUP # Control groups
# View process namespaces
ls -l /proc/self/ns/
lsns # List namespaces
# Create new namespace
unshare --pid --fork bash # New PID namespace
unshare --net bash # New network namespace
# Enter namespace
nsenter --target PID --pid --uts --net bash
System Calls
System calls provide the interface between user space and kernel space.
System Call Mechanism
User Space:
Application calls glibc function
↓
glibc wrapper function
↓
Software interrupt (int 0x80 or syscall instruction)
↓
Kernel Space:
System call handler
↓
Kernel function implementation
↓
Return to user space
Common System Calls
Process Management:
fork() // Create child process
exec() // Execute program
exit() // Terminate process
wait() // Wait for child process
getpid() // Get process ID
getppid() // Get parent process ID
kill() // Send signal to process
nice() // Change priority
File Operations:
open() // Open file
close() // Close file
read() // Read from file
write() // Write to file
lseek() // Change file position
stat() // Get file status
chmod() // Change permissions
chown() // Change ownership
link() // Create hard link
unlink() // Delete file
mkdir() // Create directory
rmdir() // Remove directory
Memory Management:
brk() // Change data segment size
mmap() // Map file or device into memory
munmap() // Unmap memory
mprotect() // Change memory protection
mlock() // Lock memory (prevent swapping)
Networking:
socket() // Create socket
bind() // Bind socket to address
listen() // Listen for connections
accept() // Accept connection
connect() // Connect to remote socket
send() // Send data
recv() // Receive data
shutdown() // Shut down socket
Tracing System Calls
strace - Trace system calls:
# Trace program execution
strace ls
strace -o output.txt ls # Save to file
# Trace specific system calls
strace -e trace=openat,read ls # Only openat and read (glibc uses openat on modern kernels)
strace -e trace=file ls # All file operations
strace -e trace=network curl example.com
# Attach to running process
strace -p PID
# Count system call statistics
strace -c ls
# Follow child processes
strace -f ./program
# Timestamp system calls
strace -t ls # Time of day
strace -T ls # Time spent in each call
# Examples
strace -e trace=open,openat cat /etc/passwd
strace -c find / -name "*.log" 2>/dev/null
strace -p $(pgrep nginx | head -1)
Writing a Simple System Call
1. Add system call to kernel:
// kernel/sys.c
SYSCALL_DEFINE1(hello, char __user *, msg)
{
char kernel_msg[256];
// strncpy_from_user stops at the string's NUL terminator, unlike a
// fixed-size copy_from_user, which may fault past a short user buffer
if (strncpy_from_user(kernel_msg, msg, sizeof(kernel_msg) - 1) < 0)
return -EFAULT;
kernel_msg[sizeof(kernel_msg) - 1] = '\0';
printk(KERN_INFO "System call hello: %s\n", kernel_msg);
return 0;
}
2. Add to system call table:
// arch/x86/entry/syscalls/syscall_64.tbl
450 common hello sys_hello
3. User space program:
#include <unistd.h>
#include <sys/syscall.h>
#define __NR_hello 450
int main() {
syscall(__NR_hello, "Hello from user space!");
return 0;
}
Kernel Modules
Kernel modules allow dynamic loading of code into the running kernel.
Module Basics
# List loaded modules
lsmod
# Module information
modinfo module_name
# Load module
modprobe module_name
insmod /path/to/module.ko
# Unload module
modprobe -r module_name
rmmod module_name
# Module dependencies
depmod -a
# Module parameters
modinfo -p module_name
modprobe module_name param=value
Writing a Simple Module
hello_module.c:
#include <linux/init.h>
#include <linux/module.h>
#include <linux/kernel.h>
MODULE_LICENSE("GPL");
MODULE_AUTHOR("Your Name");
MODULE_DESCRIPTION("A simple Hello World module");
MODULE_VERSION("1.0");
// Module initialization
static int __init hello_init(void)
{
printk(KERN_INFO "Hello World module loaded\n");
return 0; // 0 = success
}
// Module cleanup
static void __exit hello_exit(void)
{
printk(KERN_INFO "Hello World module unloaded\n");
}
// Register init and exit functions
module_init(hello_init);
module_exit(hello_exit);
Makefile:
obj-m += hello_module.o
all:
make -C /lib/modules/$(shell uname -r)/build M=$(PWD) modules
clean:
make -C /lib/modules/$(shell uname -r)/build M=$(PWD) clean
install:
make -C /lib/modules/$(shell uname -r)/build M=$(PWD) modules_install
depmod -a
Build and load:
# Compile
make
# Load module
sudo insmod hello_module.ko
# Check kernel log
dmesg | tail
# Unload module
sudo rmmod hello_module
# Install system-wide
sudo make install
Module with Parameters
#include <linux/init.h>
#include <linux/module.h>
#include <linux/kernel.h>
#include <linux/moduleparam.h>
MODULE_LICENSE("GPL");
static int count = 1;
static char *name = "World";
module_param(count, int, S_IRUGO);
module_param(name, charp, S_IRUGO);
MODULE_PARM_DESC(count, "Number of times to greet");
MODULE_PARM_DESC(name, "Name to greet");
static int __init param_init(void)
{
int i;
for (i = 0; i < count; i++) {
printk(KERN_INFO "Hello %s! (%d/%d)\n", name, i+1, count);
}
return 0;
}
static void __exit param_exit(void)
{
printk(KERN_INFO "Goodbye %s!\n", name);
}
module_init(param_init);
module_exit(param_exit);
Load with parameters:
sudo insmod param_module.ko count=3 name="Linux"
Character Device Driver
#include <linux/init.h>
#include <linux/module.h>
#include <linux/kernel.h>
#include <linux/fs.h>
#include <linux/uaccess.h>
#define DEVICE_NAME "chardev"
#define BUFFER_SIZE 1024
MODULE_LICENSE("GPL");
static int major_number;
static char device_buffer[BUFFER_SIZE];
static int buffer_size = 0;
// File operations
static int dev_open(struct inode *inodep, struct file *filep)
{
printk(KERN_INFO "chardev: Device opened\n");
return 0;
}
static ssize_t dev_read(struct file *filep, char __user *buffer,
size_t len, loff_t *offset)
{
int bytes_read = 0;
if (*offset >= buffer_size)
return 0;
bytes_read = buffer_size - *offset;
if (bytes_read > len)
bytes_read = len;
if (copy_to_user(buffer, device_buffer + *offset, bytes_read))
return -EFAULT;
*offset += bytes_read;
return bytes_read;
}
static ssize_t dev_write(struct file *filep, const char __user *buffer,
size_t len, loff_t *offset)
{
size_t bytes_written = min(len, (size_t)BUFFER_SIZE);
if (copy_from_user(device_buffer, buffer, bytes_written))
return -EFAULT;
buffer_size = bytes_written;
printk(KERN_INFO "chardev: Received %zu bytes\n", bytes_written);
return bytes_written;
}
static int dev_release(struct inode *inodep, struct file *filep)
{
printk(KERN_INFO "chardev: Device closed\n");
return 0;
}
static const struct file_operations fops = {
.owner = THIS_MODULE, // Pin the module while the device is in use
.open = dev_open,
.read = dev_read,
.write = dev_write,
.release = dev_release,
};
static int __init chardev_init(void)
{
major_number = register_chrdev(0, DEVICE_NAME, &fops);
if (major_number < 0) {
printk(KERN_ALERT "chardev: Failed to register\n");
return major_number;
}
printk(KERN_INFO "chardev: Registered with major number %d\n",
major_number);
printk(KERN_INFO "chardev: Create device with: mknod /dev/%s c %d 0\n",
DEVICE_NAME, major_number);
return 0;
}
static void __exit chardev_exit(void)
{
unregister_chrdev(major_number, DEVICE_NAME);
printk(KERN_INFO "chardev: Unregistered\n");
}
module_init(chardev_init);
module_exit(chardev_exit);
Using the device:
# Load module
sudo insmod chardev.ko
# Create device node
sudo mknod /dev/chardev c <major_number> 0
sudo chmod 666 /dev/chardev
# Test device
echo "Hello" > /dev/chardev
cat /dev/chardev
# Cleanup
sudo rm /dev/chardev
sudo rmmod chardev
Device Drivers
Driver Types
Character Devices:
- Sequential access
- Examples: keyboards, serial ports, /dev/null
- Major/minor numbers for identification
Block Devices:
- Random access, buffered I/O
- Examples: hard drives, SSDs, USB drives
- Use page cache for performance
Network Devices:
- Packet transmission/reception
- Examples: Ethernet, WiFi, loopback
- Socket interface
Device Model
# View device hierarchy
ls /sys/devices/
ls /sys/class/
# PCI devices
lspci -v
ls /sys/bus/pci/devices/
# USB devices
lsusb -v
ls /sys/bus/usb/devices/
# Block devices
lsblk
ls /sys/block/
# Network devices
ip link show
ls /sys/class/net/
# Device information
udevadm info --query=all --name=/dev/sda
Device Management with udev
udev rules (/etc/udev/rules.d/):
# Example: Custom USB device rule
# /etc/udev/rules.d/99-usb-device.rules
SUBSYSTEM=="usb", ATTR{idVendor}=="1234", ATTR{idProduct}=="5678", \
MODE="0666", GROUP="users", SYMLINK+="mydevice"
# Reload udev rules
sudo udevadm control --reload-rules
sudo udevadm trigger
# Monitor udev events
udevadm monitor
File Systems
VFS Layer
The Virtual File System provides a common interface for all filesystems.
Supported filesystems:
cat /proc/filesystems
# ext4, btrfs, xfs, nfs, vfat, tmpfs, etc.
# Filesystem modules
ls /lib/modules/$(uname -r)/kernel/fs/
ext4 Filesystem
# Create ext4 filesystem
mkfs.ext4 /dev/sdb1
# Filesystem check
fsck.ext4 /dev/sdb1
e2fsck -f /dev/sdb1
# Filesystem information
dumpe2fs /dev/sdb1
tune2fs -l /dev/sdb1
# Tune filesystem
tune2fs -m 1 /dev/sdb1 # Reserved blocks percentage
tune2fs -c 30 /dev/sdb1 # Max mount count
tune2fs -i 6m /dev/sdb1 # Check interval
# Enable/disable features
tune2fs -O has_journal /dev/sdb1 # Enable journaling
tune2fs -O ^has_journal /dev/sdb1 # Disable journaling
Filesystem Debugging
# Debugfs - interactive ext2/ext3/ext4 debugger
debugfs /dev/sdb1
# Commands: ls, cd, stat, logdump, etc.
# View inode information
stat /path/to/file
ls -i /path/to/file # Show inode number
debugfs -R "stat <inode_number>" /dev/sdb1
# Find deleted files
debugfs -R "lsdel" /dev/sdb1
Networking Stack
Network Layer Architecture
+-----------------+
| Application |
+-----------------+
| Socket Layer |
+-----------------+
| Protocol Layer | (TCP, UDP, ICMP)
+-----------------+
| IP Layer | (IPv4, IPv6, routing)
+-----------------+
| Link Layer | (Ethernet, WiFi)
+-----------------+
| Device Driver |
+-----------------+
| Hardware |
+-----------------+
Network Configuration
# View network configuration
ip addr show
ip route show
ip link show
# Network statistics
cat /proc/net/dev # Interface statistics
cat /proc/net/tcp # TCP connections
cat /proc/net/udp # UDP connections
netstat -s # Protocol statistics
# Socket buffers
sysctl net.core.rmem_max # Receive buffer
sysctl net.core.wmem_max # Send buffer
# TCP parameters
sysctl net.ipv4.tcp_rmem # TCP receive memory
sysctl net.ipv4.tcp_wmem # TCP send memory
sysctl net.ipv4.tcp_congestion_control
Network Debugging
See networking.md for detailed network debugging.
Kernel Compilation
Getting Kernel Source
# Download from kernel.org
wget https://cdn.kernel.org/pub/linux/kernel/v6.x/linux-6.5.tar.xz
tar -xf linux-6.5.tar.xz
cd linux-6.5
# Or use git
git clone https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
cd linux
git checkout v6.5
# Distribution specific
# Ubuntu/Debian
apt-get source linux-image-$(uname -r)
# Fedora/RHEL
dnf download --source kernel
Kernel Configuration
cd /usr/src/linux
# Configuration methods
make config # Text-based Q&A (tedious)
make menuconfig # Text-based menu (ncurses)
make xconfig # Qt-based GUI
make gconfig # GTK-based GUI
# Use existing config
make oldconfig # Update old config
make localmodconfig # Only modules for current hardware
make defconfig # Default configuration
cp /boot/config-$(uname -r) .config # Copy running config
# Configuration file
.config # Generated configuration
Important config options:
# General setup
CONFIG_LOCALVERSION="-custom" # Custom kernel name
CONFIG_DEFAULT_HOSTNAME="myhost"
# Processor type
CONFIG_SMP=y # Symmetric multiprocessing
CONFIG_NR_CPUS=8 # Number of CPUs
# Power management
CONFIG_CPU_FREQ=y # CPU frequency scaling
CONFIG_HIBERNATION=y
# Networking
CONFIG_NETFILTER=y # Firewall support
CONFIG_BRIDGE=y # Network bridging
# Filesystems
CONFIG_EXT4_FS=y # ext4 filesystem
CONFIG_BTRFS_FS=y # Btrfs filesystem
# Security
CONFIG_SECURITY_SELINUX=y # SELinux support
CONFIG_SECURITY_APPARMOR=y # AppArmor support
# Debugging
CONFIG_DEBUG_KERNEL=y # Kernel debugging
CONFIG_KGDB=y # Kernel debugger
CONFIG_DEBUG_INFO=y # Debug symbols
Building the Kernel
# Install build dependencies
# Ubuntu/Debian
sudo apt install build-essential libncurses-dev bison flex \
libssl-dev libelf-dev bc
# Fedora/RHEL
sudo dnf groupinstall "Development Tools"
sudo dnf install ncurses-devel bison flex elfutils-libelf-devel \
openssl-devel bc
# Build kernel
make -j$(nproc) # Use all CPU cores
# Or build specific targets
make bzImage # Kernel image
make modules # Kernel modules
make dtbs # Device tree blobs (ARM)
# Install
sudo make modules_install # Install modules to /lib/modules
sudo make install # Install kernel to /boot
# Manual installation
sudo cp arch/x86/boot/bzImage /boot/vmlinuz-6.5-custom
sudo cp System.map /boot/System.map-6.5-custom
sudo cp .config /boot/config-6.5-custom
# Update bootloader
sudo update-grub # Debian/Ubuntu
sudo grub2-mkconfig -o /boot/grub2/grub.cfg # Fedora/RHEL
# Reboot
sudo reboot
Cross-Compilation
# Install cross-compiler
sudo apt install gcc-arm-linux-gnueabi
# Configure for target architecture
make ARCH=arm CROSS_COMPILE=arm-linux-gnueabi- defconfig
# Build
make ARCH=arm CROSS_COMPILE=arm-linux-gnueabi- -j$(nproc)
# Example architectures
ARCH=arm # ARM 32-bit
ARCH=arm64 # ARM 64-bit (aarch64)
ARCH=mips # MIPS
ARCH=powerpc # PowerPC
ARCH=riscv # RISC-V
Kernel Patching
# Apply patch
patch -p1 < patch-file.patch
# Create patch
diff -Naur original/ modified/ > my-patch.patch
# Check if patch applies cleanly
patch -p1 --dry-run < patch-file.patch
# Reverse patch
patch -R -p1 < patch-file.patch
Kernel Debugging
printk - Kernel Logging
#include <linux/printk.h>
// Log levels (from highest to lowest priority)
printk(KERN_EMERG "System is unusable\n"); // 0
printk(KERN_ALERT "Action must be taken\n"); // 1
printk(KERN_CRIT "Critical conditions\n"); // 2
printk(KERN_ERR "Error conditions\n"); // 3
printk(KERN_WARNING "Warning conditions\n"); // 4
printk(KERN_NOTICE "Normal but significant\n"); // 5
printk(KERN_INFO "Informational\n"); // 6
printk(KERN_DEBUG "Debug-level messages\n"); // 7
// Default level (usually KERN_WARNING)
printk("Default level message\n");
// Dynamic debug (if CONFIG_DYNAMIC_DEBUG enabled)
pr_debug("Debug message\n");
View kernel messages:
dmesg # View kernel ring buffer
dmesg -w # Follow new messages
dmesg -l err # Only errors
dmesg --level=err,warn # Errors and warnings
dmesg -T # Human-readable timestamps
journalctl -k # Kernel messages via systemd
journalctl -k -f # Follow kernel messages
journalctl -k --since "1 hour ago"
# Set console log level
dmesg -n 1 # Only emergency messages to console
echo 7 > /proc/sys/kernel/printk # All messages to console
KGDB - Kernel Debugger
# Build kernel with debugging enabled
CONFIG_DEBUG_KERNEL=y
CONFIG_DEBUG_INFO=y
CONFIG_KGDB=y
CONFIG_KGDB_SERIAL_CONSOLE=y
# Boot with KGDB enabled
linux ... kgdboc=ttyS0,115200 kgdbwait
# Connect with GDB
gdb vmlinux
(gdb) target remote /dev/ttyS0
(gdb) break do_sys_openat2 # syscall entry symbol names vary by kernel version
(gdb) continue
kdump - Kernel Crash Dumps
# Install kdump
# Ubuntu/Debian
sudo apt install kdump-tools
# Fedora/RHEL
sudo dnf install kexec-tools
# Configure kdump
# Edit /etc/default/kdump-tools (Debian) or /etc/sysconfig/kdump (RHEL)
# Reserve memory for crash kernel
# Add to kernel parameters: crashkernel=384M-:128M
# Enable kdump
sudo systemctl enable kdump
sudo systemctl start kdump
# Test crash
echo c > /proc/sysrq-trigger # WARNING: Crashes system!
# Analyze crash dump
crash /usr/lib/debug/vmlinux-<version> /var/crash/vmcore
Magic SysRq Key
Emergency kernel functions:
# Enable SysRq
echo 1 > /proc/sys/kernel/sysrq
# SysRq commands (Alt+SysRq+<key>)
# Or: echo <key> > /proc/sysrq-trigger
b - Reboot immediately
c - Crash (for kdump)
e - SIGTERM to all processes
f - OOM killer
h - Help
i - SIGKILL to all processes
k - Kill all on current console
m - Memory info
p - Current registers and flags
r - Keyboard raw mode
s - Sync all filesystems
t - Task list
u - Remount filesystems read-only
w - Tasks in uninterruptible sleep
# Safe reboot sequence (REISUB)
# R - Raw keyboard mode
# E - SIGTERM all
# I - SIGKILL all
# S - Sync disks
# U - Remount read-only
# B - Reboot
ftrace - Function Tracer
# Mount debugfs
mount -t debugfs none /sys/kernel/debug
cd /sys/kernel/debug/tracing
# Available tracers
cat available_tracers
# function, function_graph, blk, wakeup, etc.
# Enable function tracer
echo function > current_tracer
echo 1 > tracing_on
# View trace
cat trace | head -20
# Stop tracing
echo 0 > tracing_on
# Trace a specific function (must appear in available_filter_functions)
echo vfs_read > set_ftrace_filter
echo function > current_tracer
echo 1 > tracing_on
# Clear trace
echo > trace
# Example: Trace network stack
echo 1 > events/net/enable
echo 1 > tracing_on
# Generate network traffic
cat trace
SystemTap
Dynamic tracing and instrumentation:
# Install SystemTap
sudo apt install systemtap systemtap-runtime
# Install kernel debug symbols
sudo apt install linux-image-$(uname -r)-dbgsym
# Simple script (hello.stp)
probe begin {
printf("Hello, SystemTap!\n")
exit()
}
# Run script
sudo stap hello.stp
# Trace system calls
sudo stap -e 'probe syscall.open { println(execname()) }'
# Count system calls
sudo stap -e '
global count
probe syscall.* {
count[name]++
}
probe end {
foreach (syscall in count-)
printf("%20s: %d\n", syscall, count[syscall])
}
' -c "ls -l /"
perf - Performance Analysis
# Install perf
sudo apt install linux-tools-$(uname -r)
# Record CPU cycles
sudo perf record -a sleep 10
# View report
sudo perf report
# CPU profiling
sudo perf top
# Stat command
sudo perf stat ls -R /
# Trace system calls
sudo perf trace ls
# Record specific events
sudo perf record -e sched:sched_switch -a sleep 5
sudo perf script
# Hardware counters
perf list # List available events
sudo perf stat -e cache-misses,cache-references ls
Performance Tuning
sysctl Parameters
# View all parameters
sysctl -a
# View specific parameter
sysctl vm.swappiness
# Set temporarily
sudo sysctl vm.swappiness=10
# Set permanently (/etc/sysctl.conf or /etc/sysctl.d/)
echo "vm.swappiness=10" | sudo tee -a /etc/sysctl.conf
sudo sysctl -p # Reload configuration
Important parameters:
# Virtual Memory
vm.swappiness=10 # Reduce swap usage
vm.dirty_ratio=10 # Dirty page threshold for writeback
vm.dirty_background_ratio=5 # Background writeback threshold
vm.overcommit_memory=1 # Allow memory overcommit
# Network
net.core.rmem_max=134217728 # Max receive buffer
net.core.wmem_max=134217728 # Max send buffer
net.core.netdev_max_backlog=5000 # Input queue size
net.ipv4.tcp_rmem=4096 87380 67108864 # TCP read memory
net.ipv4.tcp_wmem=4096 65536 67108864 # TCP write memory
net.ipv4.tcp_congestion_control=bbr # Congestion algorithm
net.ipv4.tcp_fastopen=3 # TCP Fast Open
net.ipv4.tcp_mtu_probing=1 # Path MTU discovery
net.ipv4.ip_forward=1 # IP forwarding
# File System
fs.file-max=2097152 # Max open files system-wide
fs.inotify.max_user_watches=524288 # Inotify watches
# Kernel
kernel.sysrq=1 # Enable SysRq
kernel.panic=10 # Reboot 10s after panic
kernel.pid_max=4194304 # Max PIDs
I/O Schedulers
# View available schedulers
cat /sys/block/sda/queue/scheduler
# [mq-deadline] kyber bfq none
# Change scheduler
echo kyber > /sys/block/sda/queue/scheduler
# Schedulers:
# mq-deadline - Default, good for most workloads
# kyber - Low latency, good for SSDs
# bfq - Fair queueing, good for desktops
# none - No scheduling (for NVMe with low latency)
# Make permanent (udev rule)
# /etc/udev/rules.d/60-scheduler.rules
ACTION=="add|change", KERNEL=="sd[a-z]", ATTR{queue/scheduler}="kyber"
CPU Governor
# View current governor
cat /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor
# Available governors
cat /sys/devices/system/cpu/cpu0/cpufreq/scaling_available_governors
# performance powersave schedutil ondemand conservative
# Set governor
echo performance | sudo tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor
# Governors:
# performance - Max frequency
# powersave - Min frequency
# ondemand - Dynamic scaling (legacy)
# schedutil - Scheduler-driven (default, recommended)
# conservative - Gradual scaling
# Using cpupower
sudo cpupower frequency-set -g performance
sudo cpupower frequency-info
Huge Pages
# Configure huge pages
echo 512 > /proc/sys/vm/nr_hugepages
# Transparent Huge Pages
echo always > /sys/kernel/mm/transparent_hugepage/enabled
echo madvise > /sys/kernel/mm/transparent_hugepage/enabled # Recommended
# View huge page usage
cat /proc/meminfo | grep -i huge
# Permanent configuration (/etc/sysctl.conf)
vm.nr_hugepages=512
NUMA (Non-Uniform Memory Access)
# Check NUMA configuration
numactl --hardware
# View NUMA statistics
numastat
# Run program on specific NUMA node
numactl --cpunodebind=0 --membind=0 ./program
# Automatic NUMA balancing
echo 1 > /proc/sys/kernel/numa_balancing
Practical Examples
Monitoring System Performance
#!/bin/bash
# System performance monitoring script
echo "=== CPU Usage ==="
mpstat 1 5 | tail -1
echo -e "\n=== Memory Usage ==="
free -h
echo -e "\n=== Disk I/O ==="
iostat -xz 1 2 | tail -n +3
echo -e "\n=== Network ==="
sar -n DEV 1 1 | tail -3
echo -e "\n=== Top Processes by CPU ==="
ps aux --sort=-%cpu | head -6
echo -e "\n=== Top Processes by Memory ==="
ps aux --sort=-%mem | head -6
echo -e "\n=== Load Average ==="
uptime
echo -e "\n=== Kernel Parameters ==="
sysctl vm.swappiness net.ipv4.tcp_congestion_control
Kernel Module Template
/**
* template_module.c - Template for kernel modules
*/
#include <linux/init.h>
#include <linux/module.h>
#include <linux/kernel.h>
#include <linux/slab.h>
MODULE_LICENSE("GPL");
MODULE_AUTHOR("Your Name");
MODULE_DESCRIPTION("Template module");
MODULE_VERSION("1.0");
static int __init template_init(void)
{
printk(KERN_INFO "template: Module loaded\n");
// Initialize your code here
return 0;
}
static void __exit template_exit(void)
{
// Cleanup your code here
printk(KERN_INFO "template: Module unloaded\n");
}
module_init(template_init);
module_exit(template_exit);
Resources
Official Documentation
Books
- "Linux Kernel Development" by Robert Love
- "Linux Device Drivers" by Jonathan Corbet, Alessandro Rubini, and Greg Kroah-Hartman
- "Understanding the Linux Kernel" by Daniel P. Bovet and Marco Cesati
- "Linux System Programming" by Robert Love
Online Resources
- The Linux Kernel Archives
- LWN.net - Linux Weekly News
- Bootlin Training Materials
Development Tools
- Git - Version control
- cscope/ctags - Code navigation
- sparse - Static analyzer
- Coccinelle - Semantic patching
- QEMU - Virtualization for testing
This guide covers the fundamentals of Linux kernel architecture and development. The kernel is vast and constantly evolving, so continuous learning and experimentation are essential!
Linux Kernel Development Patterns
Common patterns, idioms, and best practices used throughout the Linux kernel codebase.
Table of Contents
- Coding Style
- Design Patterns
- Memory Management Patterns
- Locking and Synchronization
- Error Handling
- Device Driver Patterns
- Data Structures
- Kernel APIs
- Debugging Patterns
- Best Practices
Coding Style
Basic Rules
The Linux kernel has strict coding style guidelines documented in Documentation/process/coding-style.rst.
Indentation and Formatting:
// Use tabs (8 characters) for indentation, not spaces
int function_name(int arg1, int arg2)
{
int local_var;
if (condition) {
do_something();
} else {
do_something_else();
}
return 0;
}
Line Length:
// Prefer 80 columns, maximum 100 columns
// Break long lines sensibly
static const struct file_operations my_fops = {
.owner = THIS_MODULE,
.open = my_open,
.read = my_read,
.write = my_write,
.release = my_release,
};
Naming Conventions:
// Use descriptive, lowercase names with underscores
int count_active_users(struct user_struct *user);
// Global functions should be prefixed with subsystem name
int netdev_register_device(struct net_device *dev);
// Static functions can be shorter
static int validate(void);
// Avoid Hungarian notation
int nr_pages; // Good
int iPageCount; // Bad
Braces:
// Opening brace on same line for structs and control statements
struct my_struct {
int member;
};
// But on next line for functions
int my_function(void)
{
// function body
}
// Single statement doesn't need braces (but be careful)
if (condition)
return -EINVAL;
// Multiple statements always need braces
if (condition) {
do_something();
return 0;
}
Comments
/*
* Multi-line comments use this format.
* Each line starts with a star.
* The closing star-slash is on its own line.
*/
// Single-line comments can use C++ style, but prefer /* */ style
/**
* function_name - Short description
* @param1: Description of param1
* @param2: Description of param2
*
* Longer description of what the function does.
* This can span multiple lines.
*
* Return: Description of return value
*/
int function_name(int param1, char *param2)
{
/* Implementation */
}
Design Patterns
Registration Pattern
The kernel uses registration callbacks extensively for hooking into subsystems.
/* Define operations structure */
struct my_operations {
int (*init)(void);
void (*cleanup)(void);
int (*process)(void *data);
};
/* Define registration structure */
struct my_driver {
const char *name;
struct my_operations *ops;
struct list_head list;
};
/* Registration function */
int register_my_driver(struct my_driver *driver)
{
if (!driver || !driver->ops)
return -EINVAL;
/* Add to global list with locking */
mutex_lock(&drivers_mutex);
list_add_tail(&driver->list, &drivers_list);
mutex_unlock(&drivers_mutex);
/* Initialize if needed */
if (driver->ops->init)
return driver->ops->init();
return 0;
}
/* Unregistration */
void unregister_my_driver(struct my_driver *driver)
{
mutex_lock(&drivers_mutex);
list_del(&driver->list);
mutex_unlock(&drivers_mutex);
if (driver->ops->cleanup)
driver->ops->cleanup();
}
Object-Oriented Patterns in C
The kernel implements inheritance-like patterns using structure embedding.
/* Base "class" */
struct device {
const char *name;
struct device *parent;
void (*release)(struct device *dev);
};
/* Derived "class" */
struct pci_device {
struct device dev; /* Embedded base */
unsigned int vendor;
unsigned int device_id;
};
/* Upcast: derived to base */
struct pci_device *pci_dev;
struct device *dev = &pci_dev->dev;
/* Downcast: base to derived using container_of */
struct device *dev;
struct pci_device *pci_dev = container_of(dev, struct pci_device, dev);
Reference Counting Pattern
struct my_object {
atomic_t refcount;
/* other fields */
};
/* Initialize reference count */
static void my_object_init(struct my_object *obj)
{
atomic_set(&obj->refcount, 1);
}
/* Get reference (increment) */
static inline struct my_object *my_object_get(struct my_object *obj)
{
if (obj)
atomic_inc(&obj->refcount);
return obj;
}
/* Put reference (decrement and free if zero) */
static inline void my_object_put(struct my_object *obj)
{
if (obj && atomic_dec_and_test(&obj->refcount))
my_object_destroy(obj);
}
/* Usage */
struct my_object *obj = my_object_alloc(); /* refcount = 1 */
struct my_object *obj2 = my_object_get(obj); /* refcount = 2 */
my_object_put(obj); /* refcount = 1 */
my_object_put(obj2); /* refcount = 0, object destroyed */
Kernel Object (kobject) Pattern
#include <linux/kobject.h>
struct my_object {
struct kobject kobj;
int value;
};
static struct kobj_type my_ktype = {
.release = my_release,
.sysfs_ops = &my_sysfs_ops,
.default_groups = my_groups, /* default_attrs was replaced by default_groups in newer kernels */
};
/* Create object */
struct my_object *obj = kzalloc(sizeof(*obj), GFP_KERNEL);
kobject_init(&obj->kobj, &my_ktype);
kobject_add(&obj->kobj, parent, "my_object");
/* Get reference */
kobject_get(&obj->kobj);
/* Release reference */
kobject_put(&obj->kobj);
Memory Management Patterns
Allocation Patterns
/* Kernel memory allocation */
void *ptr = kmalloc(size, GFP_KERNEL); /* Can sleep */
void *ptr = kmalloc(size, GFP_ATOMIC); /* Cannot sleep, use in interrupt */
void *ptr = kzalloc(size, GFP_KERNEL); /* Zeroed memory */
/* Large allocations */
void *ptr = vmalloc(size); /* Virtually contiguous, physically may not be */
/* Page allocation */
struct page *page = alloc_page(GFP_KERNEL);
struct page *pages = alloc_pages(GFP_KERNEL, order); /* 2^order pages */
/* Per-CPU variables */
DEFINE_PER_CPU(int, my_var);
int val = get_cpu_var(my_var);
put_cpu_var(my_var);
/* Slab/KMEM cache for frequent allocations */
struct kmem_cache *my_cache;
my_cache = kmem_cache_create("my_cache",
sizeof(struct my_struct),
0, SLAB_HWCACHE_ALIGN, NULL);
struct my_struct *obj = kmem_cache_alloc(my_cache, GFP_KERNEL);
kmem_cache_free(my_cache, obj);
Memory Barriers
/* Compiler barrier - prevent compiler reordering */
barrier();
/* Memory barriers - prevent CPU reordering */
mb(); /* Full memory barrier */
rmb(); /* Read memory barrier */
wmb(); /* Write memory barrier */
smp_mb(); /* SMP memory barrier */
/* Example: Producer-consumer */
/* Producer */
data->value = 42;
smp_wmb(); /* Ensure value is written before flag */
data->ready = 1;
/* Consumer */
while (!data->ready)
cpu_relax();
smp_rmb(); /* Ensure flag is read before value */
value = data->value;
Page Flags and Reference Counting
/* Get a page reference */
get_page(page);
/* Release a page reference */
put_page(page);
/* Check if page is locked */
if (PageLocked(page))
/* ... */
/* Lock a page */
lock_page(page);
unlock_page(page);
/* Page flags */
SetPageDirty(page);
ClearPageDirty(page);
trylock_page(page); /* Non-blocking lock attempt (replaces old TestSetPageLocked) */
Locking and Synchronization
Spinlock Pattern
/* Define spinlock */
spinlock_t my_lock;
/* Initialize */
spin_lock_init(&my_lock);
/* Use in process context */
spin_lock(&my_lock);
/* Critical section */
spin_unlock(&my_lock);
/* Use with IRQ disabling (if accessed from interrupt) */
unsigned long flags;
spin_lock_irqsave(&my_lock, flags);
/* Critical section */
spin_unlock_irqrestore(&my_lock, flags);
/* Bottom-half (softirq) protection */
spin_lock_bh(&my_lock);
/* Critical section */
spin_unlock_bh(&my_lock);
Mutex Pattern
/* Define mutex */
struct mutex my_mutex;
/* Initialize */
mutex_init(&my_mutex);
/* Use (can sleep, so only in process context) */
mutex_lock(&my_mutex);
/* Critical section */
mutex_unlock(&my_mutex);
/* Trylock */
if (mutex_trylock(&my_mutex)) {
/* Got the lock */
mutex_unlock(&my_mutex);
}
/* Interruptible lock */
if (mutex_lock_interruptible(&my_mutex))
return -EINTR;
/* Critical section */
mutex_unlock(&my_mutex);
Read-Write Locks
/* Spinlock version */
rwlock_t my_rwlock;
rwlock_init(&my_rwlock);
/* Readers */
read_lock(&my_rwlock);
/* Read data */
read_unlock(&my_rwlock);
/* Writer */
write_lock(&my_rwlock);
/* Modify data */
write_unlock(&my_rwlock);
/* Semaphore version (can sleep) */
struct rw_semaphore my_rwsem;
init_rwsem(&my_rwsem);
down_read(&my_rwsem);
/* Read data */
up_read(&my_rwsem);
down_write(&my_rwsem);
/* Modify data */
up_write(&my_rwsem);
RCU (Read-Copy-Update) Pattern
/* RCU list */
struct my_data {
int value;
struct list_head list;
struct rcu_head rcu;
};
static LIST_HEAD(my_list);
static DEFINE_SPINLOCK(list_lock);
/* Read (no lock needed!) */
rcu_read_lock();
list_for_each_entry_rcu(entry, &my_list, list) {
/* Read entry->value */
}
rcu_read_unlock();
/* Update (needs lock) */
spin_lock(&list_lock);
new = kmalloc(sizeof(*new), GFP_KERNEL);
new->value = 42;
list_add_rcu(&new->list, &my_list);
spin_unlock(&list_lock);
/* Delete */
static void my_data_free(struct rcu_head *head)
{
struct my_data *entry = container_of(head, struct my_data, rcu);
kfree(entry);
}
spin_lock(&list_lock);
list_del_rcu(&entry->list);
spin_unlock(&list_lock);
call_rcu(&entry->rcu, my_data_free); /* Deferred free */
Completion Pattern
/* Declare completion */
struct completion my_completion;
/* Initialize */
init_completion(&my_completion);
/* Wait for completion */
wait_for_completion(&my_completion);
/* Timeout version */
if (!wait_for_completion_timeout(&my_completion, msecs_to_jiffies(5000)))
printk(KERN_ERR "Timeout waiting for completion\n");
/* Signal completion */
complete(&my_completion);
/* Signal all waiters */
complete_all(&my_completion);
Atomic Operations
/* Atomic integer */
atomic_t counter = ATOMIC_INIT(0);
atomic_inc(&counter);
atomic_dec(&counter);
atomic_add(5, &counter);
atomic_sub(3, &counter);
/* Read */
int val = atomic_read(&counter);
/* Set */
atomic_set(&counter, 10);
/* Conditional operations */
if (atomic_dec_and_test(&counter))
pr_info("counter reached zero\n");
if (atomic_inc_and_test(&counter))
pr_info("counter is zero after increment\n");
/* Compare and swap */
int old = 5;
int new = 10;
atomic_cmpxchg(&counter, old, new);
/* Bitops */
unsigned long flags = 0;
set_bit(0, &flags);
clear_bit(0, &flags);
if (test_bit(0, &flags))
pr_info("bit 0 is set\n");
/* Atomic bitops */
test_and_set_bit(0, &flags);
test_and_clear_bit(0, &flags);
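The kernel atomics above map closely onto C11 `<stdatomic.h>`. As a rough userspace sketch of the semantics only (plain C, not kernel code; the `ktr_` names are invented here, and the kernel versions additionally encode per-architecture memory-ordering rules):

```c
#include <stdatomic.h>

/* Userspace analog of atomic_t with cmpxchg and dec_and_test semantics. */
typedef struct { atomic_int v; } ktr_atomic_t;

static inline void ktr_inc(ktr_atomic_t *a) { atomic_fetch_add(&a->v, 1); }
static inline int ktr_read(ktr_atomic_t *a) { return atomic_load(&a->v); }

/* Like atomic_cmpxchg: store new_val only if the current value equals old;
 * always return the value that was actually observed. */
static inline int ktr_cmpxchg(ktr_atomic_t *a, int old, int new_val)
{
	int expected = old;

	atomic_compare_exchange_strong(&a->v, &expected, new_val);
	return expected;
}

/* Like atomic_dec_and_test: decrement, return 1 if the result is zero. */
static inline int ktr_dec_and_test(ktr_atomic_t *a)
{
	return atomic_fetch_sub(&a->v, 1) == 1;
}
```

The read-modify-write calls are single indivisible operations; that is the whole point — `a->v++` on a plain `int` would compile to a separate load, add, and store that another CPU could interleave with.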
Error Handling
Error Code Pattern
/* Return negative error codes, 0 for success */
int my_function(void)
{
if (error_condition)
return -EINVAL; /* Invalid argument */
if (no_memory)
return -ENOMEM; /* Out of memory */
if (timeout)
return -ETIMEDOUT;
return 0; /* Success */
}
/* Caller checks return value */
int ret = my_function();
if (ret) {
printk(KERN_ERR "Function failed: %d\n", ret);
return ret; /* Propagate error */
}
Common Error Codes
-EINVAL /* Invalid argument */
-ENOMEM /* Out of memory */
-EFAULT /* Bad address (copy_from/to_user failed) */
-EBUSY /* Device or resource busy */
-EAGAIN /* Try again (non-blocking operation) */
-EINTR /* Interrupted system call */
-EIO /* I/O error */
-ENODEV /* No such device */
-ENOTTY /* Inappropriate ioctl for device */
-EPERM /* Operation not permitted */
-EACCES /* Permission denied */
-EEXIST /* File exists */
-ENOENT /* No such file or directory */
-ETIMEDOUT /* Connection timed out */
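The negative-errno convention is not kernel-specific; it works in any C code and keeps success (0) and failure (nonzero, self-describing) on one return path. A minimal userspace sketch, using a hypothetical `parse_positive` helper invented for illustration:

```c
#include <errno.h>
#include <stdlib.h>

/* Kernel-style convention: return 0 on success, a negative errno on failure. */
static int parse_positive(const char *s, long *out)
{
	char *end;
	long v;

	if (!s || !out)
		return -EINVAL;	/* Invalid argument */
	v = strtol(s, &end, 10);
	if (end == s || *end != '\0')
		return -EINVAL;	/* Not a number, or trailing junk */
	if (v <= 0)
		return -ERANGE;	/* Out of the accepted range */
	*out = v;
	return 0;
}
```

Callers can then propagate the error unchanged, exactly as the kernel pattern above does with `return ret;`.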
Cleanup with goto Pattern
int complex_function(void)
{
struct resource1 *res1 = NULL;
struct resource2 *res2 = NULL;
struct resource3 *res3 = NULL;
int ret;
res1 = allocate_resource1();
if (!res1) {
ret = -ENOMEM;
goto out;
}
res2 = allocate_resource2();
if (!res2) {
ret = -ENOMEM;
goto free_res1;
}
res3 = allocate_resource3();
if (!res3) {
ret = -ENOMEM;
goto free_res2;
}
/* Do work */
ret = do_work(res1, res2, res3);
if (ret)
goto free_res3;
/* Success path */
return 0;
free_res3:
free_resource3(res3);
free_res2:
free_resource2(res2);
free_res1:
free_resource1(res1);
out:
return ret;
}
ERR_PTR Pattern
/* Return pointer or error */
struct my_struct *my_function(void)
{
struct my_struct *ptr;
ptr = kmalloc(sizeof(*ptr), GFP_KERNEL);
if (!ptr)
return ERR_PTR(-ENOMEM);
if (some_error) {
kfree(ptr);
return ERR_PTR(-EINVAL);
}
return ptr;
}
/* Caller checks for error */
struct my_struct *ptr = my_function();
if (IS_ERR(ptr)) {
int err = PTR_ERR(ptr);
printk(KERN_ERR "Function failed: %d\n", err);
return err;
}
/* Use ptr */
kfree(ptr);
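The trick behind ERR_PTR is address-space layout: the top 4095 byte values of the address space are never valid kernel memory, so small negative errno values cast to pointers cannot collide with real pointers. A userspace sketch of the encoding (lowercase names to make clear these are illustrative re-implementations, not the kernel's exact definitions):

```c
#include <stdint.h>

#define MAX_ERRNO 4095

/* Encode a negative errno as a pointer in the top page of the address space. */
static inline void *err_ptr(intptr_t error) { return (void *)error; }

/* Recover the errno from an encoded pointer. */
static inline intptr_t ptr_err(const void *ptr) { return (intptr_t)ptr; }

/* True when the pointer value falls inside the reserved error range. */
static inline int is_err(const void *ptr)
{
	return (uintptr_t)ptr >= (uintptr_t)-MAX_ERRNO;
}
```

Note that NULL is not in the error range, which is why code that can fail two ways sometimes needs both a NULL check and an IS_ERR check.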
Device Driver Patterns
Character Device Pattern
#include <linux/fs.h>
#include <linux/cdev.h>
static dev_t dev_num;
static struct cdev my_cdev;
static struct class *my_class;
static int my_open(struct inode *inode, struct file *filp)
{
/* Initialize private data */
return 0;
}
static int my_release(struct inode *inode, struct file *filp)
{
/* Cleanup */
return 0;
}
static ssize_t my_read(struct file *filp, char __user *buf,
size_t count, loff_t *pos)
{
/* Read data and copy to user space (bound count by the buffer size) */
if (count > sizeof(kernel_buf))
count = sizeof(kernel_buf);
if (copy_to_user(buf, kernel_buf, count))
return -EFAULT;
return count;
}
static ssize_t my_write(struct file *filp, const char __user *buf,
size_t count, loff_t *pos)
{
/* Copy from user space and write (bound count by the buffer size) */
if (count > sizeof(kernel_buf))
count = sizeof(kernel_buf);
if (copy_from_user(kernel_buf, buf, count))
return -EFAULT;
return count;
}
static long my_ioctl(struct file *filp, unsigned int cmd, unsigned long arg)
{
switch (cmd) {
case MY_IOCTL_CMD:
/* Handle command */
break;
default:
return -ENOTTY;
}
return 0;
}
static const struct file_operations my_fops = {
.owner = THIS_MODULE,
.open = my_open,
.release = my_release,
.read = my_read,
.write = my_write,
.unlocked_ioctl = my_ioctl,
};
static int __init my_init(void)
{
int ret;
/* Allocate device number */
ret = alloc_chrdev_region(&dev_num, 0, 1, "mydev");
if (ret < 0)
return ret;
/* Initialize cdev */
cdev_init(&my_cdev, &my_fops);
my_cdev.owner = THIS_MODULE;
/* Add cdev */
ret = cdev_add(&my_cdev, dev_num, 1);
if (ret < 0)
goto unregister_chrdev;
/* Create device class */
my_class = class_create(THIS_MODULE, "myclass");
if (IS_ERR(my_class)) {
ret = PTR_ERR(my_class);
goto del_cdev;
}
/* Create device node (device_create can fail, so check it) */
if (IS_ERR(device_create(my_class, NULL, dev_num, NULL, "mydev"))) {
ret = -ENODEV;
goto destroy_class;
}
return 0;
destroy_class:
class_destroy(my_class);
del_cdev:
cdev_del(&my_cdev);
unregister_chrdev:
unregister_chrdev_region(dev_num, 1);
return ret;
}
static void __exit my_exit(void)
{
device_destroy(my_class, dev_num);
class_destroy(my_class);
cdev_del(&my_cdev);
unregister_chrdev_region(dev_num, 1);
}
module_init(my_init);
module_exit(my_exit);
MODULE_LICENSE("GPL");
Platform Device Pattern
#include <linux/platform_device.h>
static int my_probe(struct platform_device *pdev)
{
struct resource *res;
void __iomem *base;
/* Get resources */
res = platform_get_resource(pdev, IORESOURCE_MEM, 0);
if (!res)
return -ENODEV;
/* Map registers */
base = devm_ioremap_resource(&pdev->dev, res);
if (IS_ERR(base))
return PTR_ERR(base);
/* Store in device private data */
platform_set_drvdata(pdev, base);
return 0;
}
static int my_remove(struct platform_device *pdev)
{
/* Cleanup */
return 0;
}
static const struct of_device_id my_of_match[] = {
{ .compatible = "vendor,my-device" },
{ }
};
MODULE_DEVICE_TABLE(of, my_of_match);
static struct platform_driver my_driver = {
.probe = my_probe,
.remove = my_remove,
.driver = {
.name = "my-driver",
.of_match_table = my_of_match,
},
};
module_platform_driver(my_driver);
Interrupt Handler Pattern
#include <linux/interrupt.h>
static irqreturn_t my_interrupt(int irq, void *dev_id)
{
struct my_device *dev = dev_id;
u32 status;
/* Read interrupt status */
status = readl(dev->base + STATUS_REG);
if (!(status & MY_IRQ_FLAG))
return IRQ_NONE; /* Not our interrupt */
/* Clear interrupt */
writel(status, dev->base + STATUS_REG);
/* Handle interrupt - do minimal work */
/* Schedule bottom half if needed */
tasklet_schedule(&dev->tasklet);
return IRQ_HANDLED;
}
/* Bottom half (tasklet) */
static void my_tasklet_func(unsigned long data)
{
struct my_device *dev = (struct my_device *)data;
/* Do heavy work here */
}
/* Request IRQ */
ret = request_irq(irq, my_interrupt, IRQF_SHARED, "mydev", dev);
/* Free IRQ */
free_irq(irq, dev);
/* Threaded IRQ (for handlers that can sleep) */
ret = request_threaded_irq(irq, NULL, my_threaded_handler,
IRQF_ONESHOT, "mydev", dev);
Data Structures
Linked Lists
#include <linux/list.h>
struct my_node {
int data;
struct list_head list;
};
/* Define and initialize list head */
static LIST_HEAD(my_list);
/* Add entry */
struct my_node *node = kmalloc(sizeof(*node), GFP_KERNEL);
node->data = 42;
list_add(&node->list, &my_list); /* Add to head */
list_add_tail(&node->list, &my_list); /* Add to tail */
/* Iterate */
struct my_node *entry;
list_for_each_entry(entry, &my_list, list) {
printk(KERN_INFO "data: %d\n", entry->data);
}
/* Safe iteration (allows deletion) */
struct my_node *tmp;
list_for_each_entry_safe(entry, tmp, &my_list, list) {
if (entry->data == 42) {
list_del(&entry->list);
kfree(entry);
}
}
/* Check if empty */
if (list_empty(&my_list))
printk(KERN_INFO "List is empty\n");
Hash Tables
#include <linux/hashtable.h>
#define HASH_BITS 8
struct my_entry {
int key;
int value;
struct hlist_node hash;
};
/* Declare hash table */
static DEFINE_HASHTABLE(my_hash, HASH_BITS);
/* Initialize */
hash_init(my_hash);
/* Add entry */
struct my_entry *entry = kmalloc(sizeof(*entry), GFP_KERNEL);
entry->key = 123;
entry->value = 456;
hash_add(my_hash, &entry->hash, entry->key);
/* Find entry (key holds the value being looked up) */
int key = 123;
struct my_entry *found = NULL;
hash_for_each_possible(my_hash, entry, hash, key) {
if (entry->key == key) {
found = entry;
break;
}
}
/* Delete entry */
hash_del(&entry->hash);
/* Iterate all entries */
int bkt;
hash_for_each(my_hash, bkt, entry, hash) {
printk(KERN_INFO "key=%d value=%d\n", entry->key, entry->value);
}
Radix Tree
#include <linux/radix-tree.h>
static RADIX_TREE(my_tree, GFP_KERNEL);
/* Insert */
void *item = kmalloc(sizeof(struct my_data), GFP_KERNEL);
radix_tree_insert(&my_tree, index, item);
/* Lookup */
void *found = radix_tree_lookup(&my_tree, index);
/* Delete */
void *deleted = radix_tree_delete(&my_tree, index);
kfree(deleted);
/* Iterate (start is the first index to visit) */
struct radix_tree_iter iter;
void **slot;
unsigned long start = 0;
radix_tree_for_each_slot(slot, &my_tree, &iter, start) {
void *item = radix_tree_deref_slot(slot);
/* Process item */
}
Red-Black Tree
#include <linux/rbtree.h>
struct my_node {
int key;
struct rb_node node;
};
static struct rb_root my_tree = RB_ROOT;
/* Insert */
int my_insert(struct rb_root *root, struct my_node *data)
{
struct rb_node **new = &(root->rb_node), *parent = NULL;
while (*new) {
struct my_node *this = container_of(*new, struct my_node, node);
parent = *new;
if (data->key < this->key)
new = &((*new)->rb_left);
else if (data->key > this->key)
new = &((*new)->rb_right);
else
return -EEXIST;
}
rb_link_node(&data->node, parent, new);
rb_insert_color(&data->node, root);
return 0;
}
/* Search */
struct my_node *my_search(struct rb_root *root, int key)
{
struct rb_node *node = root->rb_node;
while (node) {
struct my_node *data = container_of(node, struct my_node, node);
if (key < data->key)
node = node->rb_left;
else if (key > data->key)
node = node->rb_right;
else
return data;
}
return NULL;
}
/* Erase */
rb_erase(&node->node, &my_tree);
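The rb-tree walkers above, like the RCU free callback earlier, rely on `container_of`, which steps from a pointer to an embedded member back to the enclosing struct. The arithmetic is plain pointer math over `offsetof`; a minimal userspace sketch (macro and `struct demo` invented here for illustration):

```c
#include <stddef.h>

/* Given a pointer to `member` inside an instance of `type`,
 * recover a pointer to the enclosing struct. */
#define container_of_sketch(ptr, type, member) \
	((type *)((char *)(ptr) - offsetof(type, member)))

struct demo {
	int key;
	double payload;
};
```

This is why kernel data structures embed `struct list_head`, `struct rb_node`, and `struct rcu_head` as members: the generic code only ever sees the embedded node, and `container_of` gets back to your data.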
Kernel APIs
Workqueues
#include <linux/workqueue.h>
struct work_struct my_work;
/* Work function */
static void my_work_func(struct work_struct *work)
{
/* Do work in process context */
}
/* Initialize */
INIT_WORK(&my_work, my_work_func);
/* Schedule work */
schedule_work(&my_work);
/* Delayed work */
struct delayed_work my_delayed_work;
INIT_DELAYED_WORK(&my_delayed_work, my_work_func);
schedule_delayed_work(&my_delayed_work, msecs_to_jiffies(1000));
/* Cancel work */
cancel_work_sync(&my_work);
cancel_delayed_work_sync(&my_delayed_work);
Timers
#include <linux/timer.h>
struct timer_list my_timer;
/* Timer callback */
static void my_timer_callback(struct timer_list *t)
{
/* Timer expired */
printk(KERN_INFO "Timer expired\n");
/* Reschedule if needed */
mod_timer(&my_timer, jiffies + msecs_to_jiffies(1000));
}
/* Initialize and start timer */
timer_setup(&my_timer, my_timer_callback, 0);
mod_timer(&my_timer, jiffies + msecs_to_jiffies(1000));
/* Stop timer */
del_timer_sync(&my_timer);
/* High-resolution timers */
#include <linux/hrtimer.h>
struct hrtimer my_hrtimer;
static enum hrtimer_restart my_hrtimer_callback(struct hrtimer *timer)
{
/* Timer expired */
return HRTIMER_NORESTART; /* Or HRTIMER_RESTART */
}
hrtimer_init(&my_hrtimer, CLOCK_MONOTONIC, HRTIMER_MODE_REL);
my_hrtimer.function = my_hrtimer_callback;
hrtimer_start(&my_hrtimer, ms_to_ktime(1000), HRTIMER_MODE_REL);
Wait Queues
#include <linux/wait.h>
static DECLARE_WAIT_QUEUE_HEAD(my_wait_queue);
static int condition = 0;
/* Wait for condition */
wait_event(my_wait_queue, condition != 0);
/* Wait with timeout */
int ret = wait_event_timeout(my_wait_queue, condition != 0,
msecs_to_jiffies(5000));
/* Interruptible wait */
if (wait_event_interruptible(my_wait_queue, condition != 0))
return -ERESTARTSYS;
/* Wake up waiters */
condition = 1;
wake_up(&my_wait_queue); /* Wake one */
wake_up_all(&my_wait_queue); /* Wake all */
wake_up_interruptible(&my_wait_queue);
Kernel Threads
#include <linux/kthread.h>
static struct task_struct *my_thread;
static int my_thread_func(void *data)
{
while (!kthread_should_stop()) {
/* Do work */
/* Sleep */
msleep(1000);
/* Or wait for condition */
wait_event_interruptible(queue, condition || kthread_should_stop());
}
return 0;
}
/* Create and start thread */
my_thread = kthread_run(my_thread_func, NULL, "my_thread");
if (IS_ERR(my_thread))
return PTR_ERR(my_thread);
/* Stop thread */
kthread_stop(my_thread);
Debugging Patterns
Print Debugging
/* Use appropriate log level */
printk(KERN_EMERG "Emergency\n"); /* System unusable */
printk(KERN_ALERT "Alert\n"); /* Action must be taken */
printk(KERN_CRIT "Critical\n"); /* Critical conditions */
printk(KERN_ERR "Error\n"); /* Error conditions */
printk(KERN_WARNING "Warning\n"); /* Warning conditions */
printk(KERN_NOTICE "Notice\n"); /* Normal but significant */
printk(KERN_INFO "Info\n"); /* Informational */
printk(KERN_DEBUG "Debug\n"); /* Debug messages */
/* Modern API */
pr_emerg("Emergency\n");
pr_err("Error\n");
pr_info("Info\n");
pr_debug("Debug\n"); /* Only if DEBUG is defined */
/* Device-specific logging */
dev_err(&pdev->dev, "Device error\n");
dev_info(&pdev->dev, "Device info\n");
Dynamic Debug
/* Compile with CONFIG_DYNAMIC_DEBUG */
/* Use pr_debug or dev_dbg */
pr_debug("Debug message: value=%d\n", value);
dev_dbg(&dev->dev, "Device debug: %s\n", msg);
/* Enable at runtime */
/* echo 'file mydriver.c +p' > /sys/kernel/debug/dynamic_debug/control */
Assertions
/* BUG and WARN macros */
BUG_ON(bad_condition); /* Panic if true */
WARN_ON(warning_condition); /* Warning if true */
if (WARN_ON_ONCE(ptr == NULL))
return -EINVAL;
/* Better: return error instead of crashing */
if (WARN(bad_condition, "Something went wrong: %d\n", value))
return -EINVAL;
Tracing
#include <linux/trace_events.h>
/* Use ftrace */
trace_printk("Fast trace message: %d\n", value);
/* Define tracepoints */
#include <trace/events/mydriver.h>
TRACE_EVENT(my_event,
TP_PROTO(int value),
TP_ARGS(value),
TP_STRUCT__entry(
__field(int, value)
),
TP_fast_assign(
__entry->value = value;
),
TP_printk("value=%d", __entry->value)
);
/* Use tracepoint */
trace_my_event(42);
Best Practices
Resource Management
/* Use devm_* functions for automatic cleanup on error/remove */
void __iomem *base = devm_ioremap_resource(&pdev->dev, res);
int *ptr = devm_kmalloc(&pdev->dev, size, GFP_KERNEL);
int ret = devm_request_irq(&pdev->dev, irq_num, handler, flags, name, dev);
/* These resources are released automatically when the device is removed */
Copy to/from User Space
/* Always use copy_to_user/copy_from_user */
if (copy_to_user(user_buf, kernel_buf, count))
return -EFAULT;
if (copy_from_user(kernel_buf, user_buf, count))
return -EFAULT;
/* For single values */
int value;
if (get_user(value, (int __user *)arg))
return -EFAULT;
if (put_user(value, (int __user *)arg))
return -EFAULT;
/* Check access */
if (!access_ok(user_buf, count))
return -EFAULT;
Module Parameters
/* Define module parameters */
static int debug = 0;
module_param(debug, int, 0644);
MODULE_PARM_DESC(debug, "Enable debug mode");
static char *name = "default";
module_param(name, charp, 0644);
MODULE_PARM_DESC(name, "Device name");
/* Load module with parameters */
/* insmod mymodule.ko debug=1 name="custom" */
SMP Safety
/* Always consider SMP (multiprocessor) safety */
/* Use per-CPU variables for lock-free data */
DEFINE_PER_CPU(int, my_counter);
get_cpu_var(my_counter)++; /* get_cpu_var returns an lvalue; modify it in place */
put_cpu_var(my_counter);
/* Or simply: */
this_cpu_inc(my_counter);
/* Use proper locking */
/* Identify data that needs protection */
/* Choose appropriate lock type (spinlock vs mutex) */
/* Keep critical sections short */
/* Avoid nested locks (lock ordering) */
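The lock-ordering rule is usually enforced by fixing one global acquisition order; when two locks of the same class must be held at once, a common convention is "lower address first". A sketch of the ordering decision itself (plain C; the lock type is irrelevant to the rule, and `order_locks` is a name invented here):

```c
#include <stdint.h>

/* Pick a global acquisition order for two locks: lower address first.
 * Every path that needs both lock A and lock B then agrees on the order,
 * which rules out ABBA deadlocks (CPU0 holds A waiting for B while
 * CPU1 holds B waiting for A). */
static void order_locks(void *a, void *b, void **first, void **second)
{
	if ((uintptr_t)a <= (uintptr_t)b) {
		*first = a;
		*second = b;
	} else {
		*first = b;
		*second = a;
	}
}
```

Whatever the chosen order is (address, lock class, documented hierarchy), the point is that it is total and every caller follows it.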
Power Management
/* Implement PM operations */
static int my_suspend(struct device *dev)
{
/* Save state, disable device */
return 0;
}
static int my_resume(struct device *dev)
{
/* Restore state, enable device */
return 0;
}
static const struct dev_pm_ops my_pm_ops = {
.suspend = my_suspend,
.resume = my_resume,
};
static struct platform_driver my_driver = {
.driver = {
.name = "my-driver",
.pm = &my_pm_ops,
},
};
Common Pitfalls
Don't Do This
/* DON'T use floating point in kernel */
// float x = 3.14; /* Wrong! */
/* DON'T use large stack allocations */
// char buffer[8192]; /* Too big for stack */
/* Use kmalloc instead */
/* DON'T sleep in atomic context */
spin_lock(&lock);
// msleep(100); /* Wrong! */
spin_unlock(&lock);
/* DON'T access user space directly */
// int *user_ptr;
// *user_ptr = 5; /* Wrong! Use copy_to_user */
/* DON'T ignore return values */
// kmalloc(size, GFP_KERNEL); /* Check for NULL! */
/* DON'T use unbounded loops */
// while (1) { } /* Use kthread_should_stop() */
Resources
- Kernel Documentation: Documentation/ in the kernel source
- Coding Style: Documentation/process/coding-style.rst
- API Documentation: Documentation/core-api/
- Linux Kernel Development by Robert Love
- Linux Device Drivers by Corbet, Rubini, and Kroah-Hartman
- Understanding the Linux Kernel by Bovet and Cesati
Linux kernel development follows well-established patterns that promote consistency, safety, and performance. Understanding these patterns is essential for writing quality kernel code that integrates well with the rest of the kernel.
Linux Driver Development
Comprehensive guide to developing device drivers for the Linux kernel, covering the driver model, device types, and best practices.
Table of Contents
- Introduction
- Linux Driver Model
- Device Types
- Character Device Drivers
- Platform Drivers
- Bus Drivers
- Block Device Drivers
- Network Device Drivers
- Device Tree
- Power Management
- DMA
- Interrupts
- sysfs and Device Model
- Debugging
- Best Practices
Introduction
Linux device drivers are kernel modules that provide an interface between hardware devices and the kernel. They abstract hardware complexity and provide a uniform API for user space.
Driver Architecture
┌─────────────────────────────────────┐
│ User Space │
│ (Applications, Libraries) │
└─────────────────────────────────────┘
│ System Calls
┌─────────────────────────────────────┐
│ Kernel Space │
│ ┌───────────────────────────────┐ │
│ │ Virtual File System (VFS) │ │
│ └───────────────────────────────┘ │
│ │ │
│ ┌───────────────────────────────┐ │
│ │ Device Drivers │ │
│ │ - Character Drivers │ │
│ │ - Block Drivers │ │
│ │ - Network Drivers │ │
│ └───────────────────────────────┘ │
│ │ │
│ ┌───────────────────────────────┐ │
│ │ Bus Subsystems │ │
│ │ - PCI, USB, I2C, SPI, etc. │ │
│ └───────────────────────────────┘ │
└─────────────────────────────────────┘
│
┌─────────────────────────────────────┐
│ Hardware │
└─────────────────────────────────────┘
Driver Types
- Character Drivers: Sequential access (serial ports, keyboards)
- Block Drivers: Random access (hard drives, SSDs)
- Network Drivers: Network interfaces (Ethernet, WiFi)
Linux Driver Model
The Linux driver model provides a unified framework for device management.
Core Components
Device ←→ Driver ←→ Bus
↓ ↓ ↓
struct struct struct
device driver bus
Key Structures
#include <linux/device.h>
/* Device structure */
struct device {
struct device *parent;
struct device_private *p;
struct kobject kobj;
const char *init_name;
const struct device_type *type;
struct bus_type *bus;
struct device_driver *driver;
void *platform_data;
void *driver_data;
struct dev_pm_info power;
struct dev_pm_domain *pm_domain;
int numa_node;
u64 *dma_mask;
u64 coherent_dma_mask;
struct device_dma_parameters *dma_parms;
struct list_head dma_pools;
struct dma_coherent_mem *dma_mem;
struct dev_archdata archdata;
struct device_node *of_node;
struct fwnode_handle *fwnode;
dev_t devt;
u32 id;
spinlock_t devres_lock;
struct list_head devres_head;
};
/* Driver structure */
struct device_driver {
const char *name;
struct bus_type *bus;
struct module *owner;
const char *mod_name;
bool suppress_bind_attrs;
const struct of_device_id *of_match_table;
const struct acpi_device_id *acpi_match_table;
int (*probe) (struct device *dev);
int (*remove) (struct device *dev);
void (*shutdown) (struct device *dev);
int (*suspend) (struct device *dev, pm_message_t state);
int (*resume) (struct device *dev);
const struct attribute_group **groups;
const struct dev_pm_ops *pm;
struct driver_private *p;
};
/* Bus type structure */
struct bus_type {
const char *name;
const char *dev_name;
struct device *dev_root;
const struct attribute_group **bus_groups;
const struct attribute_group **dev_groups;
const struct attribute_group **drv_groups;
int (*match)(struct device *dev, struct device_driver *drv);
int (*uevent)(struct device *dev, struct kobj_uevent_env *env);
int (*probe)(struct device *dev);
int (*remove)(struct device *dev);
void (*shutdown)(struct device *dev);
int (*suspend)(struct device *dev, pm_message_t state);
int (*resume)(struct device *dev);
const struct dev_pm_ops *pm;
struct subsys_private *p;
};
Device Registration
/* Register a device */
int device_register(struct device *dev)
{
device_initialize(dev);
return device_add(dev);
}
/* Example: Create and register a device */
static int create_my_device(struct device *parent)
{
struct device *dev;
int ret;
dev = kzalloc(sizeof(*dev), GFP_KERNEL);
if (!dev)
return -ENOMEM;
dev->parent = parent;
dev->bus = &my_bus_type;
dev_set_name(dev, "mydevice%d", id); /* id: caller-supplied instance number */
ret = device_register(dev);
if (ret) {
put_device(dev);
return ret;
}
return 0;
}
/* Unregister device */
void device_unregister(struct device *dev)
{
device_del(dev);
put_device(dev);
}
Driver Registration
/* Register a driver */
int driver_register(struct device_driver *drv)
{
int ret;
ret = bus_add_driver(drv);
if (ret)
return ret;
ret = driver_add_groups(drv, drv->groups);
if (ret) {
bus_remove_driver(drv);
return ret;
}
return 0;
}
/* Example: Register a driver */
static struct device_driver my_driver = {
.name = "my_driver",
.bus = &my_bus_type,
.probe = my_probe,
.remove = my_remove,
.pm = &my_pm_ops,
};
static int __init my_driver_init(void)
{
return driver_register(&my_driver);
}
static void __exit my_driver_exit(void)
{
driver_unregister(&my_driver);
}
module_init(my_driver_init);
module_exit(my_driver_exit);
Matching Devices and Drivers
/* Bus match function */
static int my_bus_match(struct device *dev, struct device_driver *drv)
{
struct my_device *my_dev = to_my_device(dev);
struct my_driver *my_drv = to_my_driver(drv);
/* Match by name */
if (strcmp(dev_name(dev), drv->name) == 0)
return 1;
/* Match by compatible string (device tree) */
if (of_driver_match_device(dev, drv))
return 1;
return 0;
}
Device Types
Character Devices
Sequential access devices. Most common type.
#include <linux/cdev.h>
#include <linux/fs.h>
struct my_char_dev {
struct cdev cdev;
dev_t devt;
struct class *class;
struct device *device;
/* Device-specific data */
void __iomem *base;
struct mutex lock;
};
static int my_open(struct inode *inode, struct file *filp)
{
struct my_char_dev *dev;
dev = container_of(inode->i_cdev, struct my_char_dev, cdev);
filp->private_data = dev;
pr_info("Device opened\n");
return 0;
}
static int my_release(struct inode *inode, struct file *filp)
{
pr_info("Device closed\n");
return 0;
}
static ssize_t my_read(struct file *filp, char __user *buf,
size_t count, loff_t *f_pos)
{
struct my_char_dev *dev = filp->private_data;
char kbuf[256];
size_t len;
/* Read from hardware; report EOF on repeated reads */
if (*f_pos > 0)
return 0;
len = snprintf(kbuf, sizeof(kbuf), "Hello from device\n");
if (count < len)
len = count;
if (copy_to_user(buf, kbuf, len))
return -EFAULT;
*f_pos += len;
return len;
}
static ssize_t my_write(struct file *filp, const char __user *buf,
size_t count, loff_t *f_pos)
{
struct my_char_dev *dev = filp->private_data;
char kbuf[256];
if (count > sizeof(kbuf) - 1)
count = sizeof(kbuf) - 1;
if (copy_from_user(kbuf, buf, count))
return -EFAULT;
kbuf[count] = '\0';
pr_info("Received: %s\n", kbuf);
/* Write to hardware */
return count;
}
static long my_ioctl(struct file *filp, unsigned int cmd, unsigned long arg)
{
struct my_char_dev *dev = filp->private_data;
switch (cmd) {
case MY_IOCTL_RESET:
/* Reset device */
pr_info("Reset device\n");
break;
case MY_IOCTL_GET_STATUS:
/* Get device status */
if (copy_to_user((void __user *)arg, &dev->status,
sizeof(dev->status)))
return -EFAULT;
break;
default:
return -ENOTTY;
}
return 0;
}
static const struct file_operations my_fops = {
.owner = THIS_MODULE,
.open = my_open,
.release = my_release,
.read = my_read,
.write = my_write,
.unlocked_ioctl = my_ioctl,
};
Block Devices
Random access storage devices.
#include <linux/blkdev.h>
#include <linux/genhd.h>
struct my_block_dev {
spinlock_t lock;
struct request_queue *queue;
struct gendisk *gd;
u8 *data; /* Virtual disk storage */
size_t size; /* Size in bytes */
};
static void my_request(struct request_queue *q)
{
struct request *req;
struct my_block_dev *dev = q->queuedata;
while ((req = blk_fetch_request(q)) != NULL) {
sector_t sector = blk_rq_pos(req);
unsigned long offset = sector * KERNEL_SECTOR_SIZE;
size_t len = blk_rq_bytes(req);
if (offset + len > dev->size) {
pr_err("Beyond device size\n");
__blk_end_request_all(req, -EIO);
continue;
}
if (rq_data_dir(req) == WRITE) {
/* Write to virtual disk */
memcpy(dev->data + offset, bio_data(req->bio), len);
} else {
/* Read from virtual disk */
memcpy(bio_data(req->bio), dev->data + offset, len);
}
__blk_end_request_all(req, 0);
}
}
static int my_block_open(struct block_device *bdev, fmode_t mode)
{
pr_info("Block device opened\n");
return 0;
}
static void my_block_release(struct gendisk *gd, fmode_t mode)
{
pr_info("Block device released\n");
}
static const struct block_device_operations my_bdev_ops = {
.owner = THIS_MODULE,
.open = my_block_open,
.release = my_block_release,
};
static int create_block_device(struct my_block_dev *dev)
{
int ret;
/* Allocate request queue */
spin_lock_init(&dev->lock);
dev->queue = blk_init_queue(my_request, &dev->lock);
if (!dev->queue)
return -ENOMEM;
dev->queue->queuedata = dev;
/* Allocate gendisk */
dev->gd = alloc_disk(1);
if (!dev->gd) {
blk_cleanup_queue(dev->queue);
return -ENOMEM;
}
dev->gd->major = MY_MAJOR;
dev->gd->first_minor = 0;
dev->gd->fops = &my_bdev_ops;
dev->gd->queue = dev->queue;
dev->gd->private_data = dev;
snprintf(dev->gd->disk_name, 32, "myblock");
set_capacity(dev->gd, dev->size / KERNEL_SECTOR_SIZE);
add_disk(dev->gd);
return 0;
}
Network Devices
#include <linux/netdevice.h>
#include <linux/etherdevice.h>
struct my_net_priv {
struct net_device *dev;
struct napi_struct napi;
void __iomem *base;
spinlock_t lock;
};
static int my_net_open(struct net_device *dev)
{
struct my_net_priv *priv = netdev_priv(dev);
/* Enable hardware */
/* Request IRQ */
/* Enable NAPI */
napi_enable(&priv->napi);
netif_start_queue(dev);
pr_info("Network device opened\n");
return 0;
}
static int my_net_stop(struct net_device *dev)
{
struct my_net_priv *priv = netdev_priv(dev);
netif_stop_queue(dev);
napi_disable(&priv->napi);
/* Free IRQ */
/* Disable hardware */
pr_info("Network device closed\n");
return 0;
}
static netdev_tx_t my_net_start_xmit(struct sk_buff *skb,
struct net_device *dev)
{
struct my_net_priv *priv = netdev_priv(dev);
/* Transmit packet */
/* Write to hardware TX ring */
dev->stats.tx_packets++;
dev->stats.tx_bytes += skb->len;
dev_kfree_skb(skb);
return NETDEV_TX_OK;
}
static int my_net_poll(struct napi_struct *napi, int budget)
{
struct my_net_priv *priv = container_of(napi, struct my_net_priv, napi);
struct net_device *dev = priv->dev;
int work_done = 0;
struct sk_buff *skb;
/* Process RX packets */
while (work_done < budget) {
/* Get packet from hardware */
skb = my_get_rx_packet(priv);
if (!skb)
break;
skb->dev = dev;
skb->protocol = eth_type_trans(skb, dev);
netif_receive_skb(skb);
dev->stats.rx_packets++;
dev->stats.rx_bytes += skb->len;
work_done++;
}
if (work_done < budget) {
napi_complete(napi);
/* Re-enable interrupts */
}
return work_done;
}
static const struct net_device_ops my_netdev_ops = {
.ndo_open = my_net_open,
.ndo_stop = my_net_stop,
.ndo_start_xmit = my_net_start_xmit,
};
static int create_net_device(struct device *parent)
{
struct net_device *dev;
struct my_net_priv *priv;
int ret;
dev = alloc_etherdev(sizeof(*priv));
if (!dev)
return -ENOMEM;
priv = netdev_priv(dev);
priv->dev = dev;
dev->netdev_ops = &my_netdev_ops;
dev->watchdog_timeo = 5 * HZ;
/* Set MAC address */
eth_hw_addr_random(dev);
/* Setup NAPI */
netif_napi_add(dev, &priv->napi, my_net_poll, 64);
SET_NETDEV_DEV(dev, parent);
ret = register_netdev(dev);
if (ret) {
free_netdev(dev);
return ret;
}
return 0;
}
Character Device Drivers
Complete example with multiple features.
#include <linux/module.h>
#include <linux/fs.h>
#include <linux/cdev.h>
#include <linux/device.h>
#include <linux/uaccess.h>
#define DEVICE_NAME "mychardev"
#define CLASS_NAME "myclass"
static int major_number;
static struct class *my_class;
static struct device *my_device;
static struct cdev my_cdev;
static char message[256] = "Hello from driver";
static short message_len;
static int times_opened = 0;
static int dev_open(struct inode *inode, struct file *file)
{
times_opened++;
pr_info("Device opened %d times\n", times_opened);
return 0;
}
static int dev_release(struct inode *inode, struct file *file)
{
pr_info("Device closed\n");
return 0;
}
static ssize_t dev_read(struct file *file, char __user *buffer,
size_t len, loff_t *offset)
{
int bytes_to_read;
if (*offset >= message_len)
return 0;
bytes_to_read = min(len, (size_t)(message_len - *offset));
if (copy_to_user(buffer, message + *offset, bytes_to_read))
return -EFAULT;
*offset += bytes_to_read;
pr_info("Sent %d characters to user\n", bytes_to_read);
return bytes_to_read;
}
static ssize_t dev_write(struct file *file, const char __user *buffer,
size_t len, loff_t *offset)
{
size_t bytes_to_write = min(len, sizeof(message) - 1);
if (copy_from_user(message, buffer, bytes_to_write))
return -EFAULT;
message[bytes_to_write] = '\0';
message_len = bytes_to_write;
pr_info("Received %zu characters from user\n", bytes_to_write);
return bytes_to_write;
}
static struct file_operations fops = {
.owner = THIS_MODULE,
.open = dev_open,
.release = dev_release,
.read = dev_read,
.write = dev_write,
};
static int __init chardev_init(void)
{
int ret;
dev_t dev;
/* Allocate major number */
ret = alloc_chrdev_region(&dev, 0, 1, DEVICE_NAME);
if (ret < 0) {
pr_err("Failed to allocate major number\n");
return ret;
}
major_number = MAJOR(dev);
pr_info("Registered with major number %d\n", major_number);
/* Initialize cdev */
cdev_init(&my_cdev, &fops);
my_cdev.owner = THIS_MODULE;
/* Add cdev */
ret = cdev_add(&my_cdev, dev, 1);
if (ret < 0) {
unregister_chrdev_region(dev, 1);
pr_err("Failed to add cdev\n");
return ret;
}
/* Create class */
my_class = class_create(THIS_MODULE, CLASS_NAME);
if (IS_ERR(my_class)) {
cdev_del(&my_cdev);
unregister_chrdev_region(dev, 1);
pr_err("Failed to create class\n");
return PTR_ERR(my_class);
}
/* Create device */
my_device = device_create(my_class, NULL, dev, NULL, DEVICE_NAME);
if (IS_ERR(my_device)) {
class_destroy(my_class);
cdev_del(&my_cdev);
unregister_chrdev_region(dev, 1);
pr_err("Failed to create device\n");
return PTR_ERR(my_device);
}
message_len = strlen(message);
pr_info("Character device driver loaded\n");
return 0;
}
static void __exit chardev_exit(void)
{
dev_t dev = MKDEV(major_number, 0);
device_destroy(my_class, dev);
class_destroy(my_class);
cdev_del(&my_cdev);
unregister_chrdev_region(dev, 1);
pr_info("Character device driver unloaded\n");
}
module_init(chardev_init);
module_exit(chardev_exit);
MODULE_LICENSE("GPL");
MODULE_AUTHOR("Driver Developer");
MODULE_DESCRIPTION("Simple character device driver");
Platform Drivers
Platform drivers are for devices that are not discoverable (embedded SoCs).
#include <linux/platform_device.h>
#include <linux/mod_devicetable.h>
#include <linux/io.h>
#include <linux/of.h>
struct my_platform_dev {
struct device *dev;
void __iomem *base;
struct resource *res;
int irq;
};
static int my_platform_probe(struct platform_device *pdev)
{
struct my_platform_dev *priv;
struct resource *res;
int ret;
pr_info("Platform driver probe\n");
priv = devm_kzalloc(&pdev->dev, sizeof(*priv), GFP_KERNEL);
if (!priv)
return -ENOMEM;
priv->dev = &pdev->dev;
/* Get memory resource */
res = platform_get_resource(pdev, IORESOURCE_MEM, 0);
if (!res) {
dev_err(&pdev->dev, "No memory resource\n");
return -ENODEV;
}
/* Map registers */
priv->base = devm_ioremap_resource(&pdev->dev, res);
if (IS_ERR(priv->base))
return PTR_ERR(priv->base);
/* Get IRQ */
priv->irq = platform_get_irq(pdev, 0);
if (priv->irq < 0) {
dev_err(&pdev->dev, "No IRQ resource\n");
return priv->irq;
}
/* Request IRQ */
ret = devm_request_irq(&pdev->dev, priv->irq, my_irq_handler,
IRQF_SHARED, dev_name(&pdev->dev), priv);
if (ret) {
dev_err(&pdev->dev, "Failed to request IRQ\n");
return ret;
}
/* Store private data */
platform_set_drvdata(pdev, priv);
/* Initialize hardware */
writel(0x1, priv->base + CTRL_REG);
dev_info(&pdev->dev, "Device initialized\n");
return 0;
}
static int my_platform_remove(struct platform_device *pdev)
{
struct my_platform_dev *priv = platform_get_drvdata(pdev);
/* Shutdown hardware */
writel(0x0, priv->base + CTRL_REG);
dev_info(&pdev->dev, "Device removed\n");
return 0;
}
/* Device tree match table */
static const struct of_device_id my_of_match[] = {
{ .compatible = "vendor,my-device" },
{ .compatible = "vendor,my-device-v2" },
{ }
};
MODULE_DEVICE_TABLE(of, my_of_match);
/* Platform device ID table (for non-DT systems) */
static const struct platform_device_id my_platform_ids[] = {
{ .name = "my-device" },
{ }
};
MODULE_DEVICE_TABLE(platform, my_platform_ids);
static struct platform_driver my_platform_driver = {
.probe = my_platform_probe,
.remove = my_platform_remove,
.driver = {
.name = "my-device",
.of_match_table = my_of_match,
},
.id_table = my_platform_ids,
};
module_platform_driver(my_platform_driver);
MODULE_LICENSE("GPL");
MODULE_DESCRIPTION("Platform device driver");
Bus Drivers
I2C Driver
#include <linux/i2c.h>
struct my_i2c_dev {
struct i2c_client *client;
struct device *dev;
};
/* Note: since kernel 6.3 the I2C probe callback takes only the client:
 * static int my_i2c_probe(struct i2c_client *client) */
static int my_i2c_probe(struct i2c_client *client,
const struct i2c_device_id *id)
{
struct my_i2c_dev *priv;
u8 buf[2];
int ret;
dev_info(&client->dev, "I2C device probed\n");
priv = devm_kzalloc(&client->dev, sizeof(*priv), GFP_KERNEL);
if (!priv)
return -ENOMEM;
priv->client = client;
priv->dev = &client->dev;
i2c_set_clientdata(client, priv);
/* Read device ID */
ret = i2c_smbus_read_byte_data(client, REG_ID);
if (ret < 0) {
dev_err(&client->dev, "Failed to read device ID\n");
return ret;
}
dev_info(&client->dev, "Device ID: 0x%02x\n", ret);
/* Write configuration */
buf[0] = REG_CONFIG;
buf[1] = 0x80;
ret = i2c_master_send(client, buf, 2);
if (ret < 0) {
dev_err(&client->dev, "Failed to write config\n");
return ret;
}
return 0;
}
static int my_i2c_remove(struct i2c_client *client)
{
dev_info(&client->dev, "I2C device removed\n");
return 0;
}
static const struct i2c_device_id my_i2c_ids[] = {
{ "my-i2c-device", 0 },
{ }
};
MODULE_DEVICE_TABLE(i2c, my_i2c_ids);
static const struct of_device_id my_i2c_of_match[] = {
{ .compatible = "vendor,my-i2c-device" },
{ }
};
MODULE_DEVICE_TABLE(of, my_i2c_of_match);
static struct i2c_driver my_i2c_driver = {
.driver = {
.name = "my-i2c-device",
.of_match_table = my_i2c_of_match,
},
.probe = my_i2c_probe,
.remove = my_i2c_remove,
.id_table = my_i2c_ids,
};
module_i2c_driver(my_i2c_driver);
SPI Driver
#include <linux/spi/spi.h>
struct my_spi_dev {
struct spi_device *spi;
struct device *dev;
};
static int my_spi_probe(struct spi_device *spi)
{
struct my_spi_dev *priv;
u8 tx_buf[2], rx_buf[2];
int ret;
dev_info(&spi->dev, "SPI device probed\n");
priv = devm_kzalloc(&spi->dev, sizeof(*priv), GFP_KERNEL);
if (!priv)
return -ENOMEM;
priv->spi = spi;
priv->dev = &spi->dev;
spi_set_drvdata(spi, priv);
/* Configure SPI mode and speed */
spi->mode = SPI_MODE_0;
spi->max_speed_hz = 1000000;
spi->bits_per_word = 8;
ret = spi_setup(spi);
if (ret < 0) {
dev_err(&spi->dev, "Failed to setup SPI\n");
return ret;
}
/* Read register */
tx_buf[0] = READ_CMD | REG_ID;
tx_buf[1] = 0x00;
ret = spi_write_then_read(spi, tx_buf, 1, rx_buf, 1);
if (ret < 0) {
dev_err(&spi->dev, "Failed to read register\n");
return ret;
}
dev_info(&spi->dev, "Device ID: 0x%02x\n", rx_buf[0]);
return 0;
}
static int my_spi_remove(struct spi_device *spi)
{
dev_info(&spi->dev, "SPI device removed\n");
return 0;
}
static const struct of_device_id my_spi_of_match[] = {
{ .compatible = "vendor,my-spi-device" },
{ }
};
MODULE_DEVICE_TABLE(of, my_spi_of_match);
static const struct spi_device_id my_spi_ids[] = {
{ "my-spi-device", 0 },
{ }
};
MODULE_DEVICE_TABLE(spi, my_spi_ids);
static struct spi_driver my_spi_driver = {
.driver = {
.name = "my-spi-device",
.of_match_table = my_spi_of_match,
},
.probe = my_spi_probe,
.remove = my_spi_remove,
.id_table = my_spi_ids,
};
module_spi_driver(my_spi_driver);
USB Driver
#include <linux/usb.h>
struct my_usb_dev {
struct usb_device *udev;
struct usb_interface *interface;
struct urb *int_in_urb;
unsigned char *int_in_buffer;
};
static void my_int_callback(struct urb *urb)
{
struct my_usb_dev *dev = urb->context;
int status = urb->status;
switch (status) {
case 0:
/* Success */
dev_info(&dev->interface->dev, "Data: %*ph\n",
urb->actual_length, dev->int_in_buffer);
break;
case -ECONNRESET:
case -ENOENT:
case -ESHUTDOWN:
/* URB killed */
return;
default:
dev_err(&dev->interface->dev, "URB error: %d\n", status);
break;
}
/* Resubmit URB */
usb_submit_urb(urb, GFP_ATOMIC);
}
static int my_usb_probe(struct usb_interface *interface,
const struct usb_device_id *id)
{
struct my_usb_dev *dev;
struct usb_host_interface *iface_desc;
struct usb_endpoint_descriptor *endpoint;
int ret;
dev_info(&interface->dev, "USB device probed\n");
dev = kzalloc(sizeof(*dev), GFP_KERNEL);
if (!dev)
return -ENOMEM;
dev->udev = usb_get_dev(interface_to_usbdev(interface));
dev->interface = interface;
/* Get endpoint descriptors */
iface_desc = interface->cur_altsetting;
for (int i = 0; i < iface_desc->desc.bNumEndpoints; i++) {
endpoint = &iface_desc->endpoint[i].desc;
if (usb_endpoint_is_int_in(endpoint)) {
/* Found interrupt IN endpoint */
dev->int_in_buffer = kmalloc(
le16_to_cpu(endpoint->wMaxPacketSize),
GFP_KERNEL);
if (!dev->int_in_buffer) {
ret = -ENOMEM;
goto error;
}
dev->int_in_urb = usb_alloc_urb(0, GFP_KERNEL);
if (!dev->int_in_urb) {
ret = -ENOMEM;
goto error;
}
usb_fill_int_urb(dev->int_in_urb, dev->udev,
usb_rcvintpipe(dev->udev, endpoint->bEndpointAddress),
dev->int_in_buffer,
le16_to_cpu(endpoint->wMaxPacketSize),
my_int_callback,
dev,
endpoint->bInterval);
/* Submit URB */
ret = usb_submit_urb(dev->int_in_urb, GFP_KERNEL);
if (ret) {
dev_err(&interface->dev, "Failed to submit URB\n");
goto error;
}
}
}
usb_set_intfdata(interface, dev);
return 0;
error:
if (dev->int_in_urb)
usb_free_urb(dev->int_in_urb);
kfree(dev->int_in_buffer);
usb_put_dev(dev->udev);
kfree(dev);
return ret;
}
static void my_usb_disconnect(struct usb_interface *interface)
{
struct my_usb_dev *dev;
dev = usb_get_intfdata(interface);
usb_set_intfdata(interface, NULL);
if (dev->int_in_urb) {
usb_kill_urb(dev->int_in_urb);
usb_free_urb(dev->int_in_urb);
}
kfree(dev->int_in_buffer);
usb_put_dev(dev->udev);
kfree(dev);
dev_info(&interface->dev, "USB device disconnected\n");
}
static const struct usb_device_id my_usb_table[] = {
{ USB_DEVICE(VENDOR_ID, PRODUCT_ID) },
{ }
};
MODULE_DEVICE_TABLE(usb, my_usb_table);
static struct usb_driver my_usb_driver = {
.name = "my-usb-device",
.probe = my_usb_probe,
.disconnect = my_usb_disconnect,
.id_table = my_usb_table,
};
module_usb_driver(my_usb_driver);
Block Device Drivers
(See earlier section for complete example)
Modern Block Layer (blk-mq)
#include <linux/blk-mq.h>
struct my_blk_dev {
struct blk_mq_tag_set tag_set;
struct request_queue *queue;
struct gendisk *disk;
void *data;
size_t size;
};
static blk_status_t my_queue_rq(struct blk_mq_hw_ctx *hctx,
const struct blk_mq_queue_data *bd)
{
struct request *rq = bd->rq;
struct my_blk_dev *dev = rq->q->queuedata;
struct bio_vec bvec;
struct req_iterator iter;
sector_t pos = blk_rq_pos(rq);
void *buffer;
unsigned long offset = pos * SECTOR_SIZE;
blk_mq_start_request(rq);
rq_for_each_segment(bvec, rq, iter) {
buffer = page_address(bvec.bv_page) + bvec.bv_offset;
if (rq_data_dir(rq) == WRITE)
memcpy(dev->data + offset, buffer, bvec.bv_len);
else
memcpy(buffer, dev->data + offset, bvec.bv_len);
offset += bvec.bv_len;
}
blk_mq_end_request(rq, BLK_STS_OK);
return BLK_STS_OK;
}
static const struct blk_mq_ops my_mq_ops = {
.queue_rq = my_queue_rq,
};
static int create_blkmq_device(struct my_blk_dev *dev)
{
int ret;
/* Initialize tag set */
memset(&dev->tag_set, 0, sizeof(dev->tag_set));
dev->tag_set.ops = &my_mq_ops;
dev->tag_set.nr_hw_queues = 1;
dev->tag_set.queue_depth = 128;
dev->tag_set.numa_node = NUMA_NO_NODE;
dev->tag_set.cmd_size = 0;
dev->tag_set.flags = BLK_MQ_F_SHOULD_MERGE;
dev->tag_set.driver_data = dev;
ret = blk_mq_alloc_tag_set(&dev->tag_set);
if (ret)
return ret;
/* Allocate queue */
dev->queue = blk_mq_init_queue(&dev->tag_set);
if (IS_ERR(dev->queue)) {
blk_mq_free_tag_set(&dev->tag_set);
return PTR_ERR(dev->queue);
}
dev->queue->queuedata = dev;
/* Allocate disk */
dev->disk = alloc_disk(1);
if (!dev->disk) {
blk_cleanup_queue(dev->queue);
blk_mq_free_tag_set(&dev->tag_set);
return -ENOMEM;
}
dev->disk->major = MY_MAJOR;
dev->disk->first_minor = 0;
dev->disk->fops = &my_bdev_ops;
dev->disk->queue = dev->queue;
dev->disk->private_data = dev;
snprintf(dev->disk->disk_name, 32, "myblkmq");
set_capacity(dev->disk, dev->size / SECTOR_SIZE);
add_disk(dev->disk);
return 0;
}
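On recent kernels (roughly v5.15 and later) the request queue and gendisk are allocated together with `blk_mq_alloc_disk()`, replacing the `blk_mq_init_queue()` + `alloc_disk()` pair used above; `add_disk()` also returns an error code that must be checked. A sketch under those assumptions (register names and `my_bdev_ops` as in the example above; note that from v6.9 `blk_mq_alloc_disk()` gained an extra `queue_limits` argument):

```c
static int create_blkmq_device_modern(struct my_blk_dev *dev)
{
	int ret;

	ret = blk_mq_alloc_tag_set(&dev->tag_set);
	if (ret)
		return ret;

	/* Allocates both the queue and the gendisk in one call */
	dev->disk = blk_mq_alloc_disk(&dev->tag_set, dev);
	if (IS_ERR(dev->disk)) {
		blk_mq_free_tag_set(&dev->tag_set);
		return PTR_ERR(dev->disk);
	}
	dev->queue = dev->disk->queue;

	dev->disk->major = MY_MAJOR;
	dev->disk->first_minor = 0;
	dev->disk->fops = &my_bdev_ops;
	dev->disk->private_data = dev;
	snprintf(dev->disk->disk_name, DISK_NAME_LEN, "myblkmq");
	set_capacity(dev->disk, dev->size / SECTOR_SIZE);

	return add_disk(dev->disk); /* returns an error since v5.15 */
}
```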
Network Device Drivers
(See earlier section for complete example)
Device Tree
Device tree describes hardware topology for non-discoverable devices.
Device Tree Syntax
/* my-device.dts */
/dts-v1/;
/ {
compatible = "vendor,my-board";
#address-cells = <1>;
#size-cells = <1>;
my_device: my-device@40000000 {
compatible = "vendor,my-device";
reg = <0x40000000 0x1000>;
interrupts = <0 25 4>;
clocks = <&clk_peripheral>;
clock-names = "peripheral";
status = "okay";
/* Custom properties */
vendor,feature-enable;
vendor,threshold = <100>;
vendor,string-prop = "value";
};
i2c@40005000 {
compatible = "vendor,i2c";
reg = <0x40005000 0x1000>;
#address-cells = <1>;
#size-cells = <0>;
sensor@48 {
compatible = "vendor,temperature-sensor";
reg = <0x48>;
};
};
};
Parsing Device Tree in Driver
#include <linux/of.h>
#include <linux/of_device.h>
#include <linux/of_irq.h>
#include <linux/clk.h>
#include <linux/regulator/consumer.h>
static int my_probe(struct platform_device *pdev)
{
struct device *dev = &pdev->dev;
struct device_node *np = dev->of_node;
struct resource *res;
struct clk *clk;
struct regulator *regulator;
u32 threshold;
const char *string_prop;
int irq;
int ret;
/* Check compatible string */
if (!of_device_is_compatible(np, "vendor,my-device"))
return -ENODEV;
/* Read u32 property */
ret = of_property_read_u32(np, "vendor,threshold", &threshold);
if (ret) {
dev_err(dev, "Failed to read threshold\n");
return ret;
}
dev_info(dev, "Threshold: %u\n", threshold);
/* Read string property */
ret = of_property_read_string(np, "vendor,string-prop", &string_prop);
if (ret == 0)
dev_info(dev, "String property: %s\n", string_prop);
/* Check boolean property */
if (of_property_read_bool(np, "vendor,feature-enable"))
dev_info(dev, "Feature enabled\n");
/* Get resource from reg property */
res = platform_get_resource(pdev, IORESOURCE_MEM, 0);
/* Get IRQ */
irq = irq_of_parse_and_map(np, 0);
/* Get clock */
clk = devm_clk_get(dev, "peripheral");
if (IS_ERR(clk))
return PTR_ERR(clk);
/* Get regulator */
regulator = devm_regulator_get(dev, "vdd");
return 0;
}
Power Management
#include <linux/pm.h>
#include <linux/pm_runtime.h>
/* System suspend/resume */
static int my_suspend(struct device *dev)
{
struct my_dev *priv = dev_get_drvdata(dev);
dev_info(dev, "Suspending\n");
/* Save state */
priv->saved_state = readl(priv->base + STATE_REG);
/* Disable device */
writel(0, priv->base + CTRL_REG);
/* Gate clock */
clk_disable_unprepare(priv->clk);
return 0;
}
static int my_resume(struct device *dev)
{
struct my_dev *priv = dev_get_drvdata(dev);
int ret;
dev_info(dev, "Resuming\n");
/* Ungate clock */
ret = clk_prepare_enable(priv->clk);
if (ret)
return ret;
/* Restore state */
writel(priv->saved_state, priv->base + STATE_REG);
/* Enable device */
writel(1, priv->base + CTRL_REG);
return 0;
}
/* Runtime PM */
static int my_runtime_suspend(struct device *dev)
{
struct my_dev *priv = dev_get_drvdata(dev);
dev_dbg(dev, "Runtime suspend\n");
clk_disable_unprepare(priv->clk);
return 0;
}
static int my_runtime_resume(struct device *dev)
{
struct my_dev *priv = dev_get_drvdata(dev);
int ret;
dev_dbg(dev, "Runtime resume\n");
ret = clk_prepare_enable(priv->clk);
if (ret)
return ret;
return 0;
}
static const struct dev_pm_ops my_pm_ops = {
SET_SYSTEM_SLEEP_PM_OPS(my_suspend, my_resume)
SET_RUNTIME_PM_OPS(my_runtime_suspend, my_runtime_resume, NULL)
};
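The dev_pm_ops table only takes effect once it is referenced from the driver structure, and runtime PM must additionally be enabled in probe. A minimal sketch (driver name and the 2-second autosuspend delay are arbitrary choices, not from the original):

```c
/* Reference the PM ops from the driver structure */
static struct platform_driver my_pm_driver = {
	.probe = my_platform_probe,
	.remove = my_platform_remove,
	.driver = {
		.name = "my-device",
		.pm = &my_pm_ops,
	},
};

/* In probe: enable runtime PM with autosuspend */
pm_runtime_set_autosuspend_delay(&pdev->dev, 2000);
pm_runtime_use_autosuspend(&pdev->dev);
pm_runtime_set_active(&pdev->dev);   /* hardware already initialized */
pm_runtime_enable(&pdev->dev);

/* In remove: */
pm_runtime_disable(&pdev->dev);
```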
/* Using runtime PM */
static int my_do_something(struct my_dev *priv)
{
int ret;
/* Get PM reference (resume device if suspended) */
ret = pm_runtime_get_sync(priv->dev);
if (ret < 0) {
pm_runtime_put_noidle(priv->dev);
return ret;
}
/* Do work */
writel(0x1, priv->base + CMD_REG);
/* Release PM reference */
pm_runtime_mark_last_busy(priv->dev);
pm_runtime_put_autosuspend(priv->dev);
return 0;
}
DMA
#include <linux/dma-mapping.h>
struct my_dma_dev {
struct device *dev;
dma_addr_t dma_handle;
void *cpu_addr;
size_t size;
};
/* Coherent (consistent) DMA mapping */
static int setup_coherent_dma(struct my_dma_dev *priv)
{
priv->size = 4096;
priv->cpu_addr = dma_alloc_coherent(priv->dev, priv->size,
&priv->dma_handle, GFP_KERNEL);
if (!priv->cpu_addr)
return -ENOMEM;
pr_info("DMA buffer: cpu=%p dma=%pad\n",
priv->cpu_addr, &priv->dma_handle);
/* Write data to DMA buffer */
memset(priv->cpu_addr, 0xAA, priv->size);
/* Program hardware with DMA address */
writel(priv->dma_handle, priv->base + DMA_ADDR_REG);
writel(priv->size, priv->base + DMA_SIZE_REG);
writel(DMA_START, priv->base + DMA_CTRL_REG);
return 0;
}
static void cleanup_coherent_dma(struct my_dma_dev *priv)
{
if (priv->cpu_addr) {
dma_free_coherent(priv->dev, priv->size,
priv->cpu_addr, priv->dma_handle);
priv->cpu_addr = NULL;
}
}
/* Streaming DMA mapping */
static int do_streaming_dma_tx(struct my_dma_dev *priv, void *buffer,
size_t len)
{
dma_addr_t dma_addr;
/* Map buffer for DMA */
dma_addr = dma_map_single(priv->dev, buffer, len, DMA_TO_DEVICE);
if (dma_mapping_error(priv->dev, dma_addr))
return -ENOMEM;
/* Program hardware */
writel(dma_addr, priv->base + DMA_ADDR_REG);
writel(len, priv->base + DMA_SIZE_REG);
writel(DMA_START, priv->base + DMA_CTRL_REG);
/* Wait for DMA completion (in real driver, use interrupt) */
/* Unmap buffer */
dma_unmap_single(priv->dev, dma_addr, len, DMA_TO_DEVICE);
return 0;
}
/* Scatter-gather DMA */
static int do_sg_dma(struct my_dma_dev *priv, struct scatterlist *sgl,
int nents)
{
int mapped_nents;
struct scatterlist *sg;
int i;
/* Map scatter-gather list */
mapped_nents = dma_map_sg(priv->dev, sgl, nents, DMA_TO_DEVICE);
if (!mapped_nents)
return -ENOMEM;
/* Program hardware with each SG entry */
for_each_sg(sgl, sg, mapped_nents, i) {
writel(sg_dma_address(sg),
priv->base + DMA_SG_ADDR_REG(i));
writel(sg_dma_len(sg),
priv->base + DMA_SG_LEN_REG(i));
}
writel(mapped_nents, priv->base + DMA_SG_COUNT_REG);
writel(DMA_SG_START, priv->base + DMA_CTRL_REG);
/* Wait for completion */
/* Unmap */
dma_unmap_sg(priv->dev, sgl, nents, DMA_TO_DEVICE);
return 0;
}
/* Set DMA mask */
static int setup_dma(struct device *dev)
{
int ret;
/* Try 64-bit DMA */
ret = dma_set_mask_and_coherent(dev, DMA_BIT_MASK(64));
if (ret) {
/* Fall back to 32-bit */
ret = dma_set_mask_and_coherent(dev, DMA_BIT_MASK(32));
if (ret) {
dev_err(dev, "No suitable DMA available\n");
return ret;
}
}
return 0;
}
Interrupts
#include <linux/interrupt.h>
/* Interrupt handler (top half) */
static irqreturn_t my_irq_handler(int irq, void *dev_id)
{
struct my_dev *priv = dev_id;
u32 status;
/* Read interrupt status */
status = readl(priv->base + INT_STATUS_REG);
if (!(status & MY_INT_MASK))
return IRQ_NONE; /* Not our interrupt */
/* Clear interrupt */
writel(status, priv->base + INT_STATUS_REG);
/* Minimal processing */
if (status & INT_ERROR)
priv->errors++;
/* Schedule bottom half */
schedule_work(&priv->work);
/* Or */
tasklet_schedule(&priv->tasklet);
return IRQ_HANDLED;
}
/* Bottom half (workqueue) */
static void my_work_func(struct work_struct *work)
{
struct my_dev *priv = container_of(work, struct my_dev, work);
/* Heavy processing that can sleep */
mutex_lock(&priv->lock);
/* Process data */
mutex_unlock(&priv->lock);
}
/* Bottom half (tasklet) */
static void my_tasklet_func(unsigned long data)
{
struct my_dev *priv = (struct my_dev *)data;
/* Processing that cannot sleep */
spin_lock(&priv->lock);
/* Process data */
spin_unlock(&priv->lock);
}
/* Threaded IRQ handler */
static irqreturn_t my_threaded_irq(int irq, void *dev_id)
{
struct my_dev *priv = dev_id;
/* This runs in a kernel thread, can sleep */
mutex_lock(&priv->lock);
/* Heavy processing */
mutex_unlock(&priv->lock);
return IRQ_HANDLED;
}
/* Setup interrupts */
static int setup_interrupts(struct my_dev *priv)
{
int ret;
/* Regular IRQ */
ret = devm_request_irq(priv->dev, priv->irq, my_irq_handler,
IRQF_SHARED, "my-device", priv);
if (ret) {
dev_err(priv->dev, "Failed to request IRQ\n");
return ret;
}
/* Threaded IRQ */
ret = devm_request_threaded_irq(priv->dev, priv->irq,
NULL, my_threaded_irq,
IRQF_ONESHOT, "my-device", priv);
if (ret) {
dev_err(priv->dev, "Failed to request threaded IRQ\n");
return ret;
}
/* Initialize work */
INIT_WORK(&priv->work, my_work_func);
/* Initialize tasklet (note: tasklets are deprecated in newer kernels;
 * prefer workqueues or threaded IRQs in new code) */
tasklet_init(&priv->tasklet, my_tasklet_func, (unsigned long)priv);
return 0;
}
sysfs and Device Model
#include <linux/sysfs.h>
/* sysfs attribute */
static ssize_t threshold_show(struct device *dev,
struct device_attribute *attr,
char *buf)
{
struct my_dev *priv = dev_get_drvdata(dev);
return sysfs_emit(buf, "%u\n", priv->threshold); /* bounds-checked sprintf for sysfs */
}
static ssize_t threshold_store(struct device *dev,
struct device_attribute *attr,
const char *buf, size_t count)
{
struct my_dev *priv = dev_get_drvdata(dev);
unsigned int val;
int ret;
ret = kstrtouint(buf, 0, &val);
if (ret)
return ret;
if (val > MAX_THRESHOLD)
return -EINVAL;
priv->threshold = val;
/* Update hardware */
writel(val, priv->base + THRESHOLD_REG);
return count;
}
static DEVICE_ATTR_RW(threshold);
/* Binary attribute (for large data) */
static ssize_t firmware_read(struct file *filp, struct kobject *kobj,
struct bin_attribute *attr,
char *buf, loff_t pos, size_t count)
{
struct device *dev = kobj_to_dev(kobj);
struct my_dev *priv = dev_get_drvdata(dev);
if (pos >= priv->firmware_size)
return 0;
if (pos + count > priv->firmware_size)
count = priv->firmware_size - pos;
memcpy(buf, priv->firmware + pos, count);
return count;
}
static BIN_ATTR_RO(firmware, 0);
/* Attribute group */
static struct attribute *my_attrs[] = {
&dev_attr_threshold.attr,
NULL,
};
static struct bin_attribute *my_bin_attrs[] = {
&bin_attr_firmware,
NULL,
};
static const struct attribute_group my_attr_group = {
.attrs = my_attrs,
.bin_attrs = my_bin_attrs,
};
/* Register attributes */
static int register_sysfs(struct my_dev *priv)
{
return sysfs_create_group(&priv->dev->kobj, &my_attr_group);
}
static void unregister_sysfs(struct my_dev *priv)
{
sysfs_remove_group(&priv->dev->kobj, &my_attr_group);
}
/* Alternative: have the driver core create device attributes */
static const struct attribute_group *my_attr_groups[] = {
&my_attr_group,
NULL,
};
/* Set in driver structure; attributes are created for each bound device */
static struct device_driver my_driver = {
.dev_groups = my_attr_groups,
};
Debugging
printk and dev_* Functions
/* Use appropriate log level */
pr_emerg("System is unusable\n");
pr_alert("Action must be taken immediately\n");
pr_crit("Critical conditions\n");
pr_err("Error conditions\n");
pr_warn("Warning conditions\n");
pr_notice("Normal but significant\n");
pr_info("Informational\n");
pr_debug("Debug-level messages\n");
/* Device-specific logging (preferred) */
dev_err(dev, "Device error: %d\n", err);
dev_warn(dev, "Device warning\n");
dev_info(dev, "Device information\n");
dev_dbg(dev, "Device debug\n");
/* Rate limited logging */
dev_err_ratelimited(dev, "This might happen often\n");
dev_warn_once(dev, "Only print once\n");
debugfs
#include <linux/debugfs.h>
struct my_dev {
struct dentry *debugfs_dir;
u32 debug_value;
};
static int register_debugfs(struct my_dev *priv)
{
priv->debugfs_dir = debugfs_create_dir("my-device", NULL);
/* debugfs_create_dir() returns an ERR_PTR on failure, never NULL */
if (IS_ERR(priv->debugfs_dir))
return PTR_ERR(priv->debugfs_dir);
/* Create files */
debugfs_create_u32("debug_value", 0644, priv->debugfs_dir,
&priv->debug_value);
debugfs_create_file("registers", 0444, priv->debugfs_dir,
priv, &registers_fops);
return 0;
}
static void unregister_debugfs(struct my_dev *priv)
{
debugfs_remove_recursive(priv->debugfs_dir);
}
/* Custom debugfs file operations */
static int registers_show(struct seq_file *s, void *unused)
{
struct my_dev *priv = s->private;
seq_printf(s, "CTRL: 0x%08x\n", readl(priv->base + CTRL_REG));
seq_printf(s, "STATUS: 0x%08x\n", readl(priv->base + STATUS_REG));
seq_printf(s, "DATA: 0x%08x\n", readl(priv->base + DATA_REG));
return 0;
}
static int registers_open(struct inode *inode, struct file *file)
{
return single_open(file, registers_show, inode->i_private);
}
static const struct file_operations registers_fops = {
.open = registers_open,
.read = seq_read,
.llseek = seq_lseek,
.release = single_release,
};
Tracing
/* Use trace_printk for fast debugging */
trace_printk("Fast trace: value=%d\n", value);
/* Define tracepoints */
#include <trace/events/my_driver.h>
TRACE_EVENT(my_event,
TP_PROTO(int value, const char *msg),
TP_ARGS(value, msg),
TP_STRUCT__entry(
__field(int, value)
__string(msg, msg)
),
TP_fast_assign(
__entry->value = value;
__assign_str(msg, msg);
),
TP_printk("value=%d msg=%s", __entry->value, __get_str(msg))
);
/* Use tracepoint */
trace_my_event(42, "test message");
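A TRACE_EVENT definition must live in a header with specific boilerplate, and the tracepoint symbols are emitted by including that header once with CREATE_TRACE_POINTS defined. A sketch of the header layout (file and system names `my_driver_trace` / `my_driver` are illustrative):

```c
/* my_driver_trace.h -- tracepoint definitions */
#undef TRACE_SYSTEM
#define TRACE_SYSTEM my_driver

#if !defined(_MY_DRIVER_TRACE_H) || defined(TRACE_HEADER_MULTI_READ)
#define _MY_DRIVER_TRACE_H

#include <linux/tracepoint.h>

/* The TRACE_EVENT(my_event, ...) definition from above goes here */

#endif /* _MY_DRIVER_TRACE_H */

/* Footer so <trace/define_trace.h> can re-include this header */
#undef TRACE_INCLUDE_PATH
#define TRACE_INCLUDE_PATH .
#undef TRACE_INCLUDE_FILE
#define TRACE_INCLUDE_FILE my_driver_trace
#include <trace/define_trace.h>
```

In exactly one source file, `#define CREATE_TRACE_POINTS` before including the header so the tracepoint symbols are actually emitted.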
Best Practices
Error Handling
/* Always check return values */
ret = device_register(&my_device);
if (ret) {
pr_err("Failed to register device: %d\n", ret);
goto err_register;
}
/* Use goto for cleanup */
err_register:
kfree(buffer);
err_alloc:
return ret;
/* Use devm_* functions for automatic cleanup */
priv = devm_kzalloc(dev, sizeof(*priv), GFP_KERNEL);
priv->base = devm_ioremap_resource(dev, res);
devm_request_irq(dev, irq, handler, flags, name, dev_id);
Memory Management
/* Use appropriate allocation flags */
/* GFP_KERNEL: Can sleep (process context) */
ptr = kmalloc(size, GFP_KERNEL);
/* GFP_ATOMIC: Cannot sleep (interrupt context) */
ptr = kmalloc(size, GFP_ATOMIC);
/* Always check for NULL */
if (!ptr)
return -ENOMEM;
/* Free memory */
kfree(ptr);
/* Use devm_* for automatic cleanup */
ptr = devm_kmalloc(dev, size, GFP_KERNEL);
/* No need to explicitly free */
Locking
/* Choose appropriate lock type */
/* Mutex: Can sleep, process context only */
mutex_lock(&priv->lock);
/* ... */
mutex_unlock(&priv->lock);
/* Spinlock: Cannot sleep, short critical sections */
spin_lock(&priv->lock);
/* ... */
spin_unlock(&priv->lock);
/* Spinlock with IRQ disable (accessed from IRQ) */
unsigned long flags;
spin_lock_irqsave(&priv->lock, flags);
/* ... */
spin_unlock_irqrestore(&priv->lock, flags);
Module Parameters
static int debug = 0;
module_param(debug, int, 0644);
MODULE_PARM_DESC(debug, "Enable debug output");
static char *mode = "auto";
module_param(mode, charp, 0444);
MODULE_PARM_DESC(mode, "Operating mode");
/* Use in code */
if (debug)
pr_info("Debug mode enabled\n");
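Parameters can be set at load time or persisted via modprobe configuration; a sketch using a hypothetical module name `mydriver` with the parameters defined above:

```shell
# At load time
insmod mydriver.ko debug=1 mode=manual
modprobe mydriver debug=1

# Persist across loads: /etc/modprobe.d/mydriver.conf
options mydriver debug=1 mode=manual

# Parameters with writable permissions (e.g. 0644) can be
# changed at runtime through sysfs
echo 1 > /sys/module/mydriver/parameters/debug
```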
Module Metadata
MODULE_LICENSE("GPL");
MODULE_AUTHOR("Your Name <your.email@example.com>");
MODULE_DESCRIPTION("Device driver for XYZ hardware");
MODULE_VERSION("1.0");
MODULE_ALIAS("platform:my-device");
Resources
- Linux Device Drivers (LDD3): https://lwn.net/Kernel/LDD3/
- Kernel Documentation: Documentation/driver-api/ in the kernel source
- Device Tree: Documentation/devicetree/ in the kernel source
- Example Drivers: drivers/ in the kernel source tree
- Linux Driver Development for Embedded Processors (Alberto Liberal)
- Essential Linux Device Drivers (Sreekrishnan Venkateswaran)
Linux driver development requires understanding of kernel internals, hardware interfaces, and proper resource management. Following best practices and using the kernel's device model framework ensures drivers are maintainable, efficient, and safe.
Device Tree
A comprehensive guide to Linux Device Tree, a data structure for describing hardware configuration that can be passed to the kernel at boot time.
Table of Contents
- Overview
- Why Device Tree?
- Device Tree Basics
- Device Tree Syntax
- Device Tree Structure
- Standard Properties
- Writing Device Tree Files
- Device Tree Compiler
- Parsing Device Tree in Drivers
- Common Bindings
- Platform-Specific Details
- Debugging Device Tree
- Best Practices
- Real-World Examples
Overview
Device Tree is a data structure and language for describing hardware that cannot be dynamically detected by the operating system. It's used extensively in embedded systems, especially ARM-based platforms.
Key Concepts
- Device Tree Source (.dts): Human-readable text file describing hardware
- Device Tree Blob (.dtb): Compiled binary format loaded by bootloader
- Device Tree Overlay (.dtbo): Runtime modifications to base device tree
- Bindings: Documentation defining properties for specific device types
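An overlay (.dtbo source) patches the base tree without modifying it, typically applied by the bootloader or at runtime. A minimal sketch, assuming the base tree exposes an `i2c0` label as in the examples later in this guide (the sensor node is illustrative):

```
/dts-v1/;
/plugin/;

&i2c0 {
    status = "okay";

    /* Add an I2C device without editing the base .dts */
    pressure@76 {
        compatible = "bosch,bmp280";
        reg = <0x76>;
    };
};
```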
Purpose
+-----------------+
| Bootloader |
| (U-Boot) |
+-----------------+
|
| Passes DTB
v
+-----------------+
| Linux Kernel |
+-----------------+
|
| Parses DT
v
+-----------------+
| Device Drivers |
+-----------------+
Device Tree allows:
- Hardware description separated from kernel code
- Single kernel binary supporting multiple boards
- Board-specific configuration without recompiling kernel
- Runtime hardware configuration via overlays
Why Device Tree?
Problems Device Tree Solves
Before Device Tree:
/* ARM board file - arch/arm/mach-vendor/board-xyz.c */
static struct platform_device uart0 = {
.name = "vendor-uart",
.id = 0,
.resource = {
.start = 0x44e09000,
.end = 0x44e09fff,
.flags = IORESOURCE_MEM,
},
.dev = {
.platform_data = &uart0_data,
},
};
platform_device_register(&uart0);
Problems:
- Board-specific code in kernel
- One kernel per board variant
- Difficult to maintain
- No standardization
With Device Tree:
uart0: serial@44e09000 {
compatible = "vendor,uart";
reg = <0x44e09000 0x1000>;
interrupts = <72>;
clock-frequency = <48000000>;
};
Benefits:
- Hardware description in separate file
- Single kernel for multiple boards
- Standardized bindings
- Easier to maintain
Device Tree Basics
Device Tree Hierarchy
Device Tree represents hardware as a tree of nodes:
/ {
model = "Vendor Board XYZ";
compatible = "vendor,board-xyz";
cpus {
cpu@0 {
compatible = "arm,cortex-a8";
device_type = "cpu";
reg = <0>;
};
};
memory@80000000 {
device_type = "memory";
reg = <0x80000000 0x20000000>; /* 512MB */
};
soc {
compatible = "simple-bus";
#address-cells = <1>;
#size-cells = <1>;
ranges;
uart0: serial@44e09000 {
compatible = "vendor,uart";
reg = <0x44e09000 0x1000>;
};
};
};
Key Terminology
- Node: Represents a device or bus (uart0, cpus)
- Property: Key-value pair in a node (compatible = "vendor,uart")
- Label: Named reference to a node (uart0:)
- Phandle: Reference to another node (a pointer to a node)
- Unit Address: Address part after @ (serial@44e09000)
Device Tree Syntax
Basic Syntax
/* Comments use C-style syntax */
/ {
/* Root node - always present */
node-name {
/* Properties */
property-name = "string value";
another-property = <0x12345678>;
multi-value = <0x1 0x2 0x3>;
boolean-property; /* Presence indicates true */
};
node@unit-address {
/* Node with unit address */
reg = <0x12340000 0x1000>;
};
};
Property Value Types
/ {
/* String */
model = "Vendor Board XYZ";
/* String list */
compatible = "vendor,board-xyz", "vendor,board";
/* 32-bit unsigned integers (cells) */
reg = <0x44e09000 0x1000>;
/* Multiple cells */
interrupts = <0 72 4>;
/* Boolean (empty property) */
dma-coherent;
/* Byte sequence */
mac-address = [00 11 22 33 44 55];
/* Mixed */
property = "string", <0x1234>, [AB CD];
/* Phandle reference */
interrupt-parent = <&intc>;
clocks = <&osc 0>;
};
Cell Size Specifiers
/ {
#address-cells = <1>; /* Address takes 1 cell (32-bit) */
#size-cells = <1>; /* Size takes 1 cell */
soc {
#address-cells = <1>;
#size-cells = <1>;
/* reg = <address size> */
uart0: serial@44e09000 {
reg = <0x44e09000 0x1000>;
};
};
};
/ {
#address-cells = <2>; /* 64-bit addressing */
#size-cells = <2>;
memory@0 {
/* reg = <address-high address-low size-high size-low> */
reg = <0x00000000 0x80000000 0x00000000 0x40000000>;
};
};
Labels and References
/ {
/* Define label */
intc: interrupt-controller@48200000 {
compatible = "arm,gic";
reg = <0x48200000 0x1000>;
interrupt-controller;
#interrupt-cells = <3>;
};
uart0: serial@44e09000 {
compatible = "vendor,uart";
/* Reference using phandle */
interrupt-parent = <&intc>;
interrupts = <0 72 4>;
clocks = <&sysclk>;
};
};
Includes
/* Include common definitions */
/include/ "vendor-common.dtsi"
/* Or using C preprocessor */
#include "vendor-common.dtsi"
#include <dt-bindings/gpio/gpio.h>
/ {
compatible = "vendor,board";
};
Device Tree Structure
Complete Example
/dts-v1/;
/ {
model = "Vendor Development Board";
compatible = "vendor,dev-board", "vendor,soc";
#address-cells = <1>;
#size-cells = <1>;
chosen {
bootargs = "console=ttyS0,115200 root=/dev/mmcblk0p2";
stdout-path = "/soc/serial@44e09000:115200n8";
};
memory@80000000 {
device_type = "memory";
reg = <0x80000000 0x40000000>; /* 1GB */
};
cpus {
#address-cells = <1>;
#size-cells = <0>;
cpu0: cpu@0 {
compatible = "arm,cortex-a8";
device_type = "cpu";
reg = <0>;
operating-points = <
/* kHz uV */
1000000 1350000
800000 1300000
600000 1200000
>;
clock-latency = <300000>; /* 300 us */
};
};
clocks {
osc: oscillator {
compatible = "fixed-clock";
#clock-cells = <0>;
clock-frequency = <24000000>;
};
sysclk: system-clock {
compatible = "fixed-clock";
#clock-cells = <0>;
clock-frequency = <48000000>;
};
};
soc {
compatible = "simple-bus";
#address-cells = <1>;
#size-cells = <1>;
ranges;
intc: interrupt-controller@48200000 {
compatible = "arm,cortex-a8-gic";
interrupt-controller;
#interrupt-cells = <3>;
reg = <0x48200000 0x1000>,
<0x48210000 0x2000>;
};
uart0: serial@44e09000 {
compatible = "vendor,uart", "ns16550a";
reg = <0x44e09000 0x1000>;
interrupt-parent = <&intc>;
interrupts = <0 72 4>;
clocks = <&sysclk>;
clock-names = "uart";
status = "okay";
};
i2c0: i2c@44e0b000 {
compatible = "vendor,i2c";
reg = <0x44e0b000 0x1000>;
interrupts = <0 70 4>;
#address-cells = <1>;
#size-cells = <0>;
clocks = <&sysclk>;
status = "okay";
/* I2C device */
eeprom@50 {
compatible = "atmel,24c256";
reg = <0x50>;
pagesize = <64>;
};
};
gpio0: gpio@44e07000 {
compatible = "vendor,gpio";
reg = <0x44e07000 0x1000>;
interrupts = <0 96 4>;
gpio-controller;
#gpio-cells = <2>;
interrupt-controller;
#interrupt-cells = <2>;
};
mmc0: mmc@48060000 {
compatible = "vendor,mmc";
reg = <0x48060000 0x1000>;
interrupts = <0 64 4>;
bus-width = <4>;
cd-gpios = <&gpio0 6 0>;
status = "okay";
};
};
leds {
compatible = "gpio-leds";
led0 {
label = "board:green:user0";
gpios = <&gpio0 21 0>;
linux,default-trigger = "heartbeat";
};
led1 {
label = "board:green:user1";
gpios = <&gpio0 22 0>;
default-state = "off";
};
};
regulators {
compatible = "simple-bus";
vdd_3v3: regulator@0 {
compatible = "regulator-fixed";
regulator-name = "vdd_3v3";
regulator-min-microvolt = <3300000>;
regulator-max-microvolt = <3300000>;
regulator-always-on;
};
};
};
Standard Properties
Compatible Property
The compatible property is the most important - it binds the node to a driver:
uart0: serial@44e09000 {
/* Most specific first, generic last */
compatible = "vendor,soc-uart", "vendor,uart", "ns16550a";
...
};
Driver matching:
static const struct of_device_id uart_of_match[] = {
{ .compatible = "vendor,soc-uart", .data = &soc_uart_data },
{ .compatible = "vendor,uart", .data = &generic_uart_data },
{ .compatible = "ns16550a", .data = &ns16550_data },
{ }
};
MODULE_DEVICE_TABLE(of, uart_of_match);
Reg Property
Specifies address ranges (MMIO, I2C address, SPI chip select):
/* MMIO register range */
uart0: serial@44e09000 {
reg = <0x44e09000 0x1000>; /* Base address, size */
};
/* Multiple ranges */
intc: interrupt-controller@48200000 {
reg = <0x48200000 0x1000>, /* Distributor */
<0x48210000 0x2000>; /* CPU interface */
};
/* I2C device */
eeprom@50 {
reg = <0x50>; /* I2C address */
};
/* SPI device */
flash@0 {
reg = <0>; /* Chip select 0 */
};
Status Property
Enables or disables devices:
uart0: serial@44e09000 {
status = "okay"; /* Enable */
};
uart1: serial@44e0a000 {
status = "disabled"; /* Disable */
};
uart2: serial@44e0b000 {
status = "fail"; /* Error detected */
};
Interrupt Properties
uart0: serial@44e09000 {
/* Parent interrupt controller */
interrupt-parent = <&intc>;
/* Interrupt specifier (format defined by parent) */
/* For GIC: <type number flags> */
interrupts = <0 72 4>; /* SPI, IRQ 72, level-high */
};
/* Shared interrupt */
device@0 {
interrupts = <0 50 4>;
interrupt-names = "tx", "rx", "error";
};
Clock Properties
uart0: serial@44e09000 {
clocks = <&sysclk>, <&pclk>;
clock-names = "uart", "apb_pclk";
};
/* Clock frequency for fixed clocks */
osc: oscillator {
compatible = "fixed-clock";
#clock-cells = <0>;
clock-frequency = <24000000>;
};
GPIO Properties
device {
/* GPIO specifier: <&controller pin flags> */
reset-gpios = <&gpio0 15 GPIO_ACTIVE_LOW>;
enable-gpios = <&gpio0 16 GPIO_ACTIVE_HIGH>;
};
#include <dt-bindings/gpio/gpio.h>
/* GPIO_ACTIVE_LOW, GPIO_ACTIVE_HIGH */
DMA Properties
uart0: serial@44e09000 {
dmas = <&dma 25>, <&dma 26>;
dma-names = "tx", "rx";
};
Writing Device Tree Files
Device Tree Source (.dts)
Board-specific file:
/dts-v1/;
#include "vendor-soc.dtsi"
/ {
model = "Vendor Board XYZ";
compatible = "vendor,board-xyz", "vendor,soc";
memory@80000000 {
device_type = "memory";
reg = <0x80000000 0x40000000>;
};
};
/* Enable and configure UART0 */
&uart0 {
status = "okay";
pinctrl-names = "default";
pinctrl-0 = <&uart0_pins>;
};
/* Disable UART1 (not used on this board) */
&uart1 {
status = "disabled";
};
/* Add I2C devices */
&i2c0 {
status = "okay";
clock-frequency = <400000>;
/* Board-specific I2C device */
rtc@68 {
compatible = "dallas,ds1307";
reg = <0x68>;
};
};
Device Tree Include (.dtsi)
SoC-level common definitions:
/* vendor-soc.dtsi */
/ {
#address-cells = <1>;
#size-cells = <1>;
cpus {
#address-cells = <1>;
#size-cells = <0>;
cpu@0 {
compatible = "arm,cortex-a8";
device_type = "cpu";
reg = <0>;
};
};
soc {
compatible = "simple-bus";
#address-cells = <1>;
#size-cells = <1>;
ranges;
uart0: serial@44e09000 {
compatible = "vendor,uart";
reg = <0x44e09000 0x1000>;
interrupts = <0 72 4>;
clocks = <&sysclk>;
status = "disabled"; /* Disabled by default */
};
uart1: serial@44e0a000 {
compatible = "vendor,uart";
reg = <0x44e0a000 0x1000>;
interrupts = <0 73 4>;
clocks = <&sysclk>;
status = "disabled";
};
i2c0: i2c@44e0b000 {
compatible = "vendor,i2c";
reg = <0x44e0b000 0x1000>;
interrupts = <0 70 4>;
#address-cells = <1>;
#size-cells = <0>;
clocks = <&sysclk>;
status = "disabled";
};
};
};
Overriding and Extending Nodes
/* Base definition in .dtsi */
&uart0 {
compatible = "vendor,uart";
reg = <0x44e09000 0x1000>;
status = "disabled";
};
/* Board-specific .dts */
&uart0 {
status = "okay";
pinctrl-names = "default";
pinctrl-0 = <&uart0_pins>;
/* Adds new properties while keeping existing ones */
};
Deleting Nodes/Properties
/* Delete property */
&uart0 {
/delete-property/ dmas;
/delete-property/ dma-names;
};
/* Delete node */
&uart1 {
/delete-node/ device@0;
};
Device Tree Compiler
Compiling Device Tree
# Compile .dts to .dtb
dtc -I dts -O dtb -o board.dtb board.dts
# With includes
dtc -I dts -O dtb -o board.dtb -i include_path board.dts
# Using C preprocessor
cpp -nostdinc -I include_path -undef -x assembler-with-cpp \
board.dts board.preprocessed.dts
dtc -I dts -O dtb -o board.dtb board.preprocessed.dts
Decompiling Device Tree
# Decompile .dtb to .dts
dtc -I dtb -O dts -o board.dts board.dtb
# With symbols for overlays
dtc -I dtb -O dts -o board.dts board.dtb -@
Building with Kernel
# In kernel Makefile
dtb-$(CONFIG_BOARD_XYZ) += board-xyz.dtb
# Build
make dtbs
# Output in: arch/arm/boot/dts/board-xyz.dtb
Validation
# Check syntax
dtc -I dts -O dtb -o /dev/null board.dts
# Validate against schema (Linux 5.4+)
make dt_binding_check
make dtbs_check
Parsing Device Tree in Drivers
Getting Device Tree Node
#include <linux/of.h>
#include <linux/of_device.h>
static int my_probe(struct platform_device *pdev)
{
struct device *dev = &pdev->dev;
struct device_node *np = dev->of_node;
if (!np) {
dev_err(dev, "No device tree node\n");
return -ENODEV;
}
/* Node is available */
return 0;
}
Reading Properties
/* Read string */
const char *model;
if (of_property_read_string(np, "model", &model) == 0) {
pr_info("Model: %s\n", model);
}
/* Read u32 */
u32 clock_freq;
if (of_property_read_u32(np, "clock-frequency", &clock_freq) == 0) {
pr_info("Clock: %u Hz\n", clock_freq);
}
/* Read u32 array */
u32 values[3];
int count = of_property_read_u32_array(np, "interrupts", values, 3);
/* Read u64 */
u64 reg_base;
of_property_read_u64(np, "reg", &reg_base);
/* Check if property exists */
if (of_property_read_bool(np, "dma-coherent")) {
pr_info("DMA coherent enabled\n");
}
Getting Resources
/* Get memory resource */
struct resource *res;
res = platform_get_resource(pdev, IORESOURCE_MEM, 0);
if (!res)
return -ENODEV;
void __iomem *base = devm_ioremap_resource(dev, res);
if (IS_ERR(base))
return PTR_ERR(base);
/* Get IRQ */
int irq = platform_get_irq(pdev, 0);
if (irq < 0)
return irq;
/* Get register address/size directly */
u64 addr, size;
of_property_read_u64_index(np, "reg", 0, &addr);
of_property_read_u64_index(np, "reg", 1, &size);
Parsing Phandles
/* Get referenced node */
struct device_node *clk_np;
clk_np = of_parse_phandle(np, "clocks", 0);
if (!clk_np) {
dev_err(dev, "No clock specified\n");
return -EINVAL;
}
/* Get clock */
struct clk *clk = of_clk_get(np, 0);
if (IS_ERR(clk))
return PTR_ERR(clk);
/* Or by name */
clk = of_clk_get_by_name(np, "uart");
GPIO Handling
#include <linux/of_gpio.h>
/* Get GPIO */
int reset_gpio = of_get_named_gpio(np, "reset-gpios", 0);
if (!gpio_is_valid(reset_gpio))
return -EINVAL;
/* Request and configure */
devm_gpio_request_one(dev, reset_gpio, GPIOF_OUT_INIT_LOW, "reset");
/* Using GPIO descriptor API (preferred) */
#include <linux/gpio/consumer.h>
struct gpio_desc *reset_gpiod;
reset_gpiod = devm_gpiod_get(dev, "reset", GPIOD_OUT_LOW);
if (IS_ERR(reset_gpiod))
return PTR_ERR(reset_gpiod);
gpiod_set_value(reset_gpiod, 1);
Iterating Child Nodes
struct device_node *child;
for_each_child_of_node(np, child) {
const char *name;
u32 reg;
of_property_read_string(child, "label", &name);
of_property_read_u32(child, "reg", &reg);
pr_info("Child: %s at 0x%x\n", name, reg);
}
Complete Driver Example
#include <linux/module.h>
#include <linux/platform_device.h>
#include <linux/of.h>
#include <linux/of_device.h>
#include <linux/clk.h>
#include <linux/gpio/consumer.h>
struct my_device {
void __iomem *base;
struct clk *clk;
int irq;
struct gpio_desc *reset_gpio;
u32 clock_freq;
};
static int my_probe(struct platform_device *pdev)
{
struct device *dev = &pdev->dev;
struct device_node *np = dev->of_node;
struct my_device *priv;
struct resource *res;
int ret;
priv = devm_kzalloc(dev, sizeof(*priv), GFP_KERNEL);
if (!priv)
return -ENOMEM;
/* Get memory resource */
res = platform_get_resource(pdev, IORESOURCE_MEM, 0);
priv->base = devm_ioremap_resource(dev, res);
if (IS_ERR(priv->base))
return PTR_ERR(priv->base);
/* Get IRQ */
priv->irq = platform_get_irq(pdev, 0);
if (priv->irq < 0)
return priv->irq;
/* Get clock */
priv->clk = devm_clk_get(dev, "uart");
if (IS_ERR(priv->clk)) {
dev_err(dev, "Failed to get clock\n");
return PTR_ERR(priv->clk);
}
/* Get GPIO */
priv->reset_gpio = devm_gpiod_get_optional(dev, "reset", GPIOD_OUT_LOW);
if (IS_ERR(priv->reset_gpio))
return PTR_ERR(priv->reset_gpio);
/* Read clock frequency */
ret = of_property_read_u32(np, "clock-frequency", &priv->clock_freq);
if (ret) {
/* Use default if not specified */
priv->clock_freq = 48000000;
}
/* Enable clock */
ret = clk_prepare_enable(priv->clk);
if (ret)
return ret;
/* Reset device */
if (priv->reset_gpio) {
gpiod_set_value(priv->reset_gpio, 1);
msleep(10);
gpiod_set_value(priv->reset_gpio, 0);
}
platform_set_drvdata(pdev, priv);
dev_info(dev, "Device initialized (clock=%u Hz)\n", priv->clock_freq);
return 0;
}
static int my_remove(struct platform_device *pdev)
{
struct my_device *priv = platform_get_drvdata(pdev);
clk_disable_unprepare(priv->clk);
return 0;
}
static const struct of_device_id my_of_match[] = {
{ .compatible = "vendor,my-device" },
{ }
};
MODULE_DEVICE_TABLE(of, my_of_match);
static struct platform_driver my_driver = {
.probe = my_probe,
.remove = my_remove,
.driver = {
.name = "my-device",
.of_match_table = my_of_match,
},
};
module_platform_driver(my_driver);
MODULE_LICENSE("GPL");
MODULE_AUTHOR("Your Name");
MODULE_DESCRIPTION("Device Tree Example Driver");
Device Tree:
my_device: device@44e09000 {
compatible = "vendor,my-device";
reg = <0x44e09000 0x1000>;
interrupts = <0 72 4>;
clocks = <&sysclk>;
clock-names = "uart";
reset-gpios = <&gpio0 15 GPIO_ACTIVE_LOW>;
clock-frequency = <48000000>;
};
Common Bindings
I2C Devices
&i2c0 {
#address-cells = <1>;
#size-cells = <0>;
eeprom@50 {
compatible = "atmel,24c256";
reg = <0x50>;
pagesize = <64>;
};
rtc@68 {
compatible = "dallas,ds1307";
reg = <0x68>;
interrupts = <0 75 IRQ_TYPE_EDGE_FALLING>;
};
};
SPI Devices
&spi0 {
#address-cells = <1>;
#size-cells = <0>;
flash@0 {
compatible = "jedec,spi-nor";
reg = <0>; /* Chip select 0 */
spi-max-frequency = <20000000>;
partitions {
compatible = "fixed-partitions";
#address-cells = <1>;
#size-cells = <1>;
partition@0 {
label = "bootloader";
reg = <0x000000 0x100000>;
read-only;
};
partition@100000 {
label = "kernel";
reg = <0x100000 0x400000>;
};
partition@500000 {
label = "rootfs";
reg = <0x500000 0xb00000>;
};
};
};
};
Regulators
regulators {
compatible = "simple-bus";
#address-cells = <1>;
#size-cells = <0>;
vdd_core: regulator@0 {
compatible = "regulator-fixed";
reg = <0>;
regulator-name = "vdd_core";
regulator-min-microvolt = <1200000>;
regulator-max-microvolt = <1200000>;
regulator-always-on;
regulator-boot-on;
};
vdd_3v3: regulator@1 {
compatible = "regulator-gpio";
reg = <1>;
regulator-name = "vdd_3v3";
regulator-min-microvolt = <3300000>;
regulator-max-microvolt = <3300000>;
enable-gpio = <&gpio0 20 GPIO_ACTIVE_HIGH>;
enable-active-high;
};
};
/* Usage */
&uart0 {
vdd-supply = <&vdd_3v3>;
};
Pinctrl (Pin Multiplexing)
pinctrl: pinctrl@44e10800 {
compatible = "vendor,pinctrl";
reg = <0x44e10800 0x1000>;
uart0_pins: uart0_pins {
pinctrl-single,pins = <
0x170 (PIN_INPUT_PULLUP | MUX_MODE0) /* uart0_rxd */
0x174 (PIN_OUTPUT_PULLDOWN | MUX_MODE0) /* uart0_txd */
>;
};
i2c0_pins: i2c0_pins {
pinctrl-single,pins = <
0x188 (PIN_INPUT_PULLUP | MUX_MODE0) /* i2c0_sda */
0x18c (PIN_INPUT_PULLUP | MUX_MODE0) /* i2c0_scl */
>;
};
};
&uart0 {
pinctrl-names = "default";
pinctrl-0 = <&uart0_pins>;
};
&i2c0 {
pinctrl-names = "default";
pinctrl-0 = <&i2c0_pins>;
};
Platform-Specific Details
ARM Device Tree
/dts-v1/;
/ {
model = "ARM Versatile Express";
compatible = "arm,vexpress";
#address-cells = <1>;
#size-cells = <1>;
cpus {
#address-cells = <1>;
#size-cells = <0>;
cpu@0 {
device_type = "cpu";
compatible = "arm,cortex-a9";
reg = <0>;
};
cpu@1 {
device_type = "cpu";
compatible = "arm,cortex-a9";
reg = <1>;
};
};
};
ARM64 Device Tree
/dts-v1/;
/ {
#address-cells = <2>; /* 64-bit addressing */
#size-cells = <2>;
cpus {
#address-cells = <1>;
#size-cells = <0>;
cpu@0 {
device_type = "cpu";
compatible = "arm,cortex-a57";
reg = <0x0>;
enable-method = "psci";
};
};
memory@80000000 {
device_type = "memory";
reg = <0x0 0x80000000 0x0 0x80000000>; /* 2GB */
};
};
Raspberry Pi Example
/dts-v1/;
#include "bcm2835.dtsi"
/ {
compatible = "raspberrypi,model-b", "brcm,bcm2835";
model = "Raspberry Pi Model B";
memory@0 {
device_type = "memory";
reg = <0 0x20000000>; /* 512 MB */
};
};
&uart0 {
status = "okay";
};
&i2c1 {
status = "okay";
clock-frequency = <100000>;
};
&sdhci {
status = "okay";
bus-width = <4>;
};
Debugging Device Tree
Viewing Loaded Device Tree
# View device tree in /proc
cat /proc/device-tree/model
# Or using dtc
dtc -I fs -O dts /proc/device-tree
# Better formatting
dtc -I fs -O dts -o /tmp/current.dts /proc/device-tree
Sysfs Device Tree
# Navigate device tree in sysfs
ls /sys/firmware/devicetree/base/
# View property
cat /sys/firmware/devicetree/base/model
# View all properties of a node
ls -la /sys/firmware/devicetree/base/soc/serial@44e09000/
Kernel Boot Messages
# Check device tree loading
dmesg | grep -i "device tree"
dmesg | grep -i "dtb"
# Check OF (Open Firmware) messages
dmesg | grep -i "of:"
Driver Matching Debug
/* In driver code */
static int my_probe(struct platform_device *pdev)
{
struct device *dev = &pdev->dev;
struct device_node *np = dev->of_node;
dev_info(dev, "Device tree node: %pOF\n", np);
dev_info(dev, "Compatible: %s\n",
of_get_property(np, "compatible", NULL));
/* Print all properties */
struct property *prop;
for_each_property_of_node(np, prop) {
dev_info(dev, "Property: %s\n", prop->name);
}
return 0;
}
Common Issues
Device not probing:
# Check if device is in device tree
ls /sys/firmware/devicetree/base/soc/
# Check driver registration
ls /sys/bus/platform/drivers/
# Check devices without driver
cat /sys/kernel/debug/devices_deferred
Compatible string mismatch:
/* Check driver's compatible strings */
static const struct of_device_id my_of_match[] = {
{ .compatible = "vendor,device-v2" }, /* Try this first */
{ .compatible = "vendor,device" }, /* Then this */
{ }
};
Best Practices
DO's
- Use specific compatible strings first:
compatible = "vendor,soc-uart-v2", "vendor,uart", "ns16550a";
- Disable devices by default in SoC .dtsi:
/* In SoC .dtsi */
uart0: serial@44e09000 {
status = "disabled";
};
/* In board .dts */
&uart0 {
status = "okay";
};
- Use labels for references:
uart0: serial@44e09000 { ... };
&uart0 {
/* Override properties */
};
- Document bindings:
# Documentation/devicetree/bindings/serial/vendor-uart.yaml
title: Vendor UART Controller
properties:
compatible:
const: vendor,uart
reg:
maxItems: 1
interrupts:
maxItems: 1
- Use standard property names:
clock-frequency, not clock-freq; reset-gpios, not reset-gpio
- Follow bindings in Documentation/devicetree/bindings/
DON'Ts
- Don't duplicate information:
/* Bad - IRQ already specified in interrupts */
uart0 {
interrupts = <72>;
irq-number = <72>; /* Redundant */
};
/* Good */
uart0 {
interrupts = <72>;
};
- Don't use Linux-specific information:
/* Bad - driver name is Linux-specific */
uart0 {
linux,driver-name = "vendor-uart";
};
/* Good - use compatible */
uart0 {
compatible = "vendor,uart";
};
- Don't hardcode board-specific data in drivers:
/* Bad - hardcoded in driver */
#define UART_BASE 0x44e09000
/* Good - read from device tree */
res = platform_get_resource(pdev, IORESOURCE_MEM, 0);
Real-World Examples
BeagleBone Black
/dts-v1/;
#include "am33xx.dtsi"
/ {
model = "TI AM335x BeagleBone Black";
compatible = "ti,am335x-bone-black", "ti,am335x-bone", "ti,am33xx";
memory@80000000 {
device_type = "memory";
reg = <0x80000000 0x20000000>; /* 512 MB */
};
leds {
compatible = "gpio-leds";
pinctrl-names = "default";
pinctrl-0 = <&user_leds_s0>;
led0 {
label = "beaglebone:green:usr0";
gpios = <&gpio1 21 GPIO_ACTIVE_HIGH>;
linux,default-trigger = "heartbeat";
default-state = "off";
};
};
};
&uart0 {
pinctrl-names = "default";
pinctrl-0 = <&uart0_pins>;
status = "okay";
};
&mmc1 {
vmmc-supply = <&vmmcsd_fixed>;
bus-width = <4>;
status = "okay";
};
Raspberry Pi 4
/dts-v1/;
#include "bcm2711.dtsi"
/ {
compatible = "raspberrypi,4-model-b", "brcm,bcm2711";
model = "Raspberry Pi 4 Model B";
memory@0 {
device_type = "memory";
reg = <0x0 0x0 0x0 0x80000000>; /* 2GB */
};
aliases {
serial0 = &uart0;
serial1 = &uart1;
};
};
&uart0 {
pinctrl-names = "default";
pinctrl-0 = <&uart0_gpio14>;
status = "okay";
};
&i2c1 {
pinctrl-names = "default";
pinctrl-0 = <&i2c1_gpio2>;
clock-frequency = <100000>;
status = "okay";
};
Summary
Device Tree provides:
- Hardware description separated from kernel code
- Single kernel for multiple boards
- Runtime configuration
- Standardized hardware description
Key points:
- Use .dts for board-specific files and .dtsi for SoC-common definitions
- The compatible property binds nodes to drivers
- Use standard properties and follow the bindings documentation
- Parse the device tree in drivers using the OF APIs
- Debug using /proc/device-tree and /sys/firmware/devicetree
Cross Compilation
A comprehensive guide to cross compilation for Linux - building software on one platform (host) to run on a different platform (target).
Table of Contents
- Overview
- Why Cross Compilation?
- Terminology
- Toolchain Setup
- Cross Compiling the Linux Kernel
- Cross Compiling User Space Applications
- Build System Support
- Root Filesystem Creation
- Debugging Cross-Compiled Code
- Common Architectures
- Troubleshooting
- Best Practices
Overview
Cross Compilation is the process of building executable code on one system (the host) that will run on a different system (the target). This is essential for embedded systems development where the target device may have limited resources or a different architecture.
Typical Scenario
┌─────────────────────┐ ┌─────────────────────┐
│ Host System │ │ Target System │
│ x86_64 Linux │ │ ARM Cortex-A8 │
│ Development PC │ ────> │ Embedded Board │
│ │ │ │
│ - Fast CPU │ │ - Slow CPU │
│ - Lots of RAM │ │ - Limited RAM │
│ - Large Storage │ │ - Small Storage │
└─────────────────────┘ └─────────────────────┘
Build on host → Deploy to target
Why Cross Compilation?
Reasons for Cross Compilation
-
Limited Target Resources
- Embedded devices lack CPU power, RAM, or storage for compilation
- Building natively would take hours or fail due to memory constraints
-
Architecture Differences
- Development machine (x86_64) differs from target (ARM, MIPS, etc.)
- Cannot run x86 binaries on ARM without emulation
-
Speed
- Powerful development machine compiles much faster than embedded target
- Native compilation on Raspberry Pi: 2 hours → Cross compilation: 10 minutes
-
Tooling
- Better development tools available on host
- Easier debugging and profiling setup
-
Consistency
- Reproducible builds across team
- Controlled toolchain versions
Example: Raspberry Pi
Native compilation on Pi 3:
# Building Linux kernel natively
$ time make -j4
real 120m0.000s # 2 hours!
Cross compilation on x86_64 PC:
# Cross compiling same kernel
$ time make ARCH=arm CROSS_COMPILE=arm-linux-gnueabihf- -j8
real 10m0.000s # 10 minutes!
Terminology
Key Terms
-
Host: System where compilation happens (your development PC)
-
Build: System where build tools run (usually same as host)
-
Target: System where compiled code will run (embedded device)
-
Toolchain: Collection of tools for cross compilation
- Compiler (gcc, clang)
- Linker (ld)
- Assembler (as)
- Libraries (libc, libgcc)
- Utilities (objcopy, objdump, strip)
-
Triple/Tuple: Architecture specification format
- Format: arch-vendor-os-abi
- Example: arm-linux-gnueabihf
- arm: Architecture (ARM)
- linux: OS (Linux)
- gnueabihf: ABI (GNU EABI Hard Float)
-
Sysroot: Target system's root filesystem on host
- Contains target's headers and libraries
- Located on development machine
- Mimics target's /usr, /lib, etc.
Architecture Tuples
# Common architecture tuples
arm-linux-gnueabi # ARM soft-float
arm-linux-gnueabihf # ARM hard-float
aarch64-linux-gnu # ARM 64-bit
mips-linux-gnu # MIPS
mipsel-linux-gnu # MIPS little-endian
powerpc-linux-gnu # PowerPC
x86_64-w64-mingw32 # Windows 64-bit
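The tuples above can be split mechanically. A Python sketch, illustrative only: three-part tuples are actually ambiguous between arch-vendor-os and arch-os-abi, so this assumes the Linux-style arch-os-abi form shown above:

```python
def parse_triple(triple):
    """Split a toolchain tuple such as arm-linux-gnueabihf.

    Canonical form is arch-vendor-os-abi; the vendor field is often
    omitted. Three-part tuples are assumed to be arch-os-abi here.
    """
    parts = triple.split("-")
    if len(parts) == 4:
        arch, vendor, os_name, abi = parts
    elif len(parts) == 3:
        arch, os_name, abi = parts
        vendor = "unknown"
    else:
        raise ValueError(f"unrecognized tuple: {triple}")
    return {"arch": arch, "vendor": vendor, "os": os_name, "abi": abi}

print(parse_triple("arm-linux-gnueabihf"))
print(parse_triple("arm-unknown-linux-gnueabi"))
```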
Toolchain Setup
Option 1: Pre-built Toolchains
Install from Distribution:
# Ubuntu/Debian - ARM
sudo apt-get install gcc-arm-linux-gnueabihf g++-arm-linux-gnueabihf
# ARM64
sudo apt-get install gcc-aarch64-linux-gnu g++-aarch64-linux-gnu
# MIPS
sudo apt-get install gcc-mips-linux-gnu g++-mips-linux-gnu
# Verify installation
arm-linux-gnueabihf-gcc --version
Linaro Toolchains:
# Download from Linaro
wget https://releases.linaro.org/components/toolchain/binaries/latest-7/arm-linux-gnueabihf/gcc-linaro-7.5.0-2019.12-x86_64_arm-linux-gnueabihf.tar.xz
# Extract
tar xf gcc-linaro-7.5.0-2019.12-x86_64_arm-linux-gnueabihf.tar.xz
# Add to PATH
export PATH=$PATH:$(pwd)/gcc-linaro-7.5.0-2019.12-x86_64_arm-linux-gnueabihf/bin
# Test
arm-linux-gnueabihf-gcc --version
Option 2: Crosstool-NG
Build custom toolchains:
# Install crosstool-NG
git clone https://github.com/crosstool-ng/crosstool-ng
cd crosstool-ng
./bootstrap
./configure --prefix=/opt/crosstool-ng
make
sudo make install
# Add to PATH
export PATH=/opt/crosstool-ng/bin:$PATH
# Configure and build
ct-ng list-samples
ct-ng arm-unknown-linux-gnueabi
ct-ng menuconfig # Configure as needed
ct-ng build
# Toolchain installed in ~/x-tools/arm-unknown-linux-gnueabi/
Option 3: Buildroot
Creates complete embedded Linux system including toolchain:
# Download Buildroot
wget https://buildroot.org/downloads/buildroot-2023.02.tar.gz
tar xf buildroot-2023.02.tar.gz
cd buildroot-2023.02
# Configure
make menuconfig
# Target options -> Target Architecture -> ARM
# Toolchain -> Build toolchain
# Build
make
# Toolchain in output/host/usr/bin/
export PATH=$PATH:$(pwd)/output/host/usr/bin
Setting Up Environment
Permanent setup:
# Add to ~/.bashrc or ~/.zshrc
export CROSS_COMPILE=arm-linux-gnueabihf-
export ARCH=arm
export PATH=$PATH:/path/to/toolchain/bin
# Apply
source ~/.bashrc
Project-specific:
# Create toolchain.env
cat > toolchain.env << 'EOF'
export CROSS_COMPILE=arm-linux-gnueabihf-
export ARCH=arm
export PATH=/opt/arm-toolchain/bin:$PATH
export SYSROOT=/opt/arm-sysroot
EOF
# Source when needed
source toolchain.env
Verifying Toolchain
# Check compiler
${CROSS_COMPILE}gcc --version
${CROSS_COMPILE}gcc -v
# Check target
${CROSS_COMPILE}gcc -dumpmachine
# Output: arm-linux-gnueabihf
# List all tools
ls -la $(dirname $(which ${CROSS_COMPILE}gcc))/${CROSS_COMPILE}*
Cross Compiling the Linux Kernel
Basic Kernel Cross Compilation
# Get kernel source
git clone https://github.com/torvalds/linux.git
cd linux
# Clean
make mrproper
# Configure for ARM (example: Versatile Express)
make ARCH=arm vexpress_defconfig
# Or use menuconfig
make ARCH=arm menuconfig
# Build
make ARCH=arm CROSS_COMPILE=arm-linux-gnueabihf- -j$(nproc)
# Build specific targets
make ARCH=arm CROSS_COMPILE=arm-linux-gnueabihf- zImage
make ARCH=arm CROSS_COMPILE=arm-linux-gnueabihf- modules
make ARCH=arm CROSS_COMPILE=arm-linux-gnueabihf- dtbs
# Install modules to staging directory
make ARCH=arm CROSS_COMPILE=arm-linux-gnueabihf- \
INSTALL_MOD_PATH=/path/to/rootfs modules_install
Raspberry Pi Kernel
# Clone Raspberry Pi kernel
git clone --depth=1 https://github.com/raspberrypi/linux
cd linux
# Pi 1, Zero, Zero W (32-bit)
make ARCH=arm CROSS_COMPILE=arm-linux-gnueabihf- bcmrpi_defconfig
# Pi 2, 3, 4 (32-bit)
make ARCH=arm CROSS_COMPILE=arm-linux-gnueabihf- bcm2709_defconfig
# Pi 3, 4 (64-bit)
make ARCH=arm64 CROSS_COMPILE=aarch64-linux-gnu- bcm2711_defconfig
# Build
make ARCH=arm CROSS_COMPILE=arm-linux-gnueabihf- zImage modules dtbs -j$(nproc)
# Install to SD card
export ROOTFS=/mnt/ext4
export BOOTFS=/mnt/fat32
# Install modules
sudo make ARCH=arm CROSS_COMPILE=arm-linux-gnueabihf- \
INSTALL_MOD_PATH=$ROOTFS modules_install
# Install kernel
sudo cp arch/arm/boot/zImage $BOOTFS/kernel7.img
sudo cp arch/arm/boot/dts/*.dtb $BOOTFS/
sudo cp arch/arm/boot/dts/overlays/*.dtb* $BOOTFS/overlays/
BeagleBone Black Kernel
# Clone kernel
git clone https://github.com/beagleboard/linux.git
cd linux
# Checkout stable branch
git checkout 5.10
# Configure
make ARCH=arm CROSS_COMPILE=arm-linux-gnueabihf- bb.org_defconfig
# Build
make ARCH=arm CROSS_COMPILE=arm-linux-gnueabihf- zImage modules dtbs -j$(nproc)
# Create uImage (U-Boot format)
make ARCH=arm CROSS_COMPILE=arm-linux-gnueabihf- uImage \
LOADADDR=0x80008000
# Install
sudo cp arch/arm/boot/uImage /media/$USER/BOOT/
sudo cp arch/arm/boot/dts/am335x-boneblack.dtb /media/$USER/BOOT/
sudo make ARCH=arm CROSS_COMPILE=arm-linux-gnueabihf- \
INSTALL_MOD_PATH=/media/$USER/rootfs modules_install
Kernel Configuration Tips
# Use existing config from target
scp user@target:/proc/config.gz .
zcat config.gz > .config
make ARCH=arm CROSS_COMPILE=arm-linux-gnueabihf- oldconfig
# Save custom config
make ARCH=arm CROSS_COMPILE=arm-linux-gnueabihf- savedefconfig
cp defconfig arch/arm/configs/myboard_defconfig
# Enable specific features
./scripts/config --enable CONFIG_FEATURE_NAME
./scripts/config --disable CONFIG_FEATURE_NAME
./scripts/config --module CONFIG_FEATURE_NAME
Cross Compiling User Space Applications
Simple C Program
/* hello.c */
#include <stdio.h>
int main(void)
{
printf("Hello from %s!\n",
#ifdef __arm__
"ARM"
#elif defined(__aarch64__)
"ARM64"
#elif defined(__mips__)
"MIPS"
#else
"unknown"
#endif
);
return 0;
}
Compile:
# Cross compile
arm-linux-gnueabihf-gcc hello.c -o hello
# Check architecture
file hello
# hello: ELF 32-bit LSB executable, ARM, version 1 (SYSV)
# Check dynamic libraries
arm-linux-gnueabihf-readelf -d hello | grep NEEDED
Static vs Dynamic Linking
Dynamic linking (default):
# Requires target's libc at runtime
arm-linux-gnueabihf-gcc hello.c -o hello
# List dependencies
arm-linux-gnueabihf-ldd hello
# or
arm-linux-gnueabihf-readelf -d hello
Static linking:
# Includes all libraries in binary
arm-linux-gnueabihf-gcc hello.c -o hello -static
# Check - no dependencies
file hello
# hello: ELF 32-bit LSB executable, ARM, statically linked
# Size comparison
ls -lh hello
# Much larger with static linking
Cross Compiling with Libraries
/* http_client.c - requires libcurl */
#include <curl/curl.h>
#include <stdio.h>
int main(void)
{
CURL *curl = curl_easy_init();
if (curl) {
curl_easy_cleanup(curl);
printf("libcurl working!\n");
}
return 0;
}
Without sysroot (will fail):
arm-linux-gnueabihf-gcc http_client.c -o http_client -lcurl
# Error: curl/curl.h: No such file or directory
With sysroot:
# Install target libraries on host (requires multiarch to be enabled)
sudo dpkg --add-architecture armhf
sudo apt-get update
sudo apt-get install libcurl4-openssl-dev:armhf
# Compile with sysroot
arm-linux-gnueabihf-gcc http_client.c -o http_client \
--sysroot=/usr/arm-linux-gnueabihf \
-lcurl
# Or set PKG_CONFIG
export PKG_CONFIG_PATH=/usr/arm-linux-gnueabihf/lib/pkgconfig
arm-linux-gnueabihf-gcc http_client.c -o http_client \
$(pkg-config --cflags --libs libcurl)
Makefile for Cross Compilation
# Makefile
CC := $(CROSS_COMPILE)gcc
CXX := $(CROSS_COMPILE)g++
LD := $(CROSS_COMPILE)ld
AR := $(CROSS_COMPILE)ar
STRIP := $(CROSS_COMPILE)strip
CFLAGS := -Wall -O2
LDFLAGS :=
# Add sysroot if set
ifdef SYSROOT
CFLAGS += --sysroot=$(SYSROOT)
LDFLAGS += --sysroot=$(SYSROOT)
endif
TARGET := myapp
SRCS := main.c utils.c
OBJS := $(SRCS:.c=.o)
all: $(TARGET)
$(TARGET): $(OBJS)
$(CC) $(LDFLAGS) -o $@ $^
$(STRIP) $@
%.o: %.c
$(CC) $(CFLAGS) -c -o $@ $<
clean:
rm -f $(OBJS) $(TARGET)
.PHONY: all clean
Usage:
# Native compilation
make
# Cross compilation
make CROSS_COMPILE=arm-linux-gnueabihf-
# With sysroot
make CROSS_COMPILE=arm-linux-gnueabihf- SYSROOT=/opt/arm-sysroot
Build System Support
Autotools (./configure)
# Basic cross compilation
./configure --host=arm-linux-gnueabihf --prefix=/usr
# With sysroot
./configure \
--host=arm-linux-gnueabihf \
--prefix=/usr \
--with-sysroot=/opt/arm-sysroot \
CFLAGS="--sysroot=/opt/arm-sysroot" \
LDFLAGS="--sysroot=/opt/arm-sysroot"
# Build and install
make
make DESTDIR=/path/to/rootfs install
config.site for consistent configuration:
# Create config.site
cat > arm-config.site << 'EOF'
# Cross compilation settings
ac_cv_func_malloc_0_nonnull=yes
ac_cv_func_realloc_0_nonnull=yes
EOF
# Use it
./configure --host=arm-linux-gnueabihf --prefix=/usr \
CONFIG_SITE=arm-config.site
CMake
Toolchain file:
# arm-toolchain.cmake
set(CMAKE_SYSTEM_NAME Linux)
set(CMAKE_SYSTEM_PROCESSOR arm)
# Specify the cross compiler
set(CMAKE_C_COMPILER arm-linux-gnueabihf-gcc)
set(CMAKE_CXX_COMPILER arm-linux-gnueabihf-g++)
# Where to look for libraries
set(CMAKE_FIND_ROOT_PATH /usr/arm-linux-gnueabihf)
# Search for programs in the build host directories
set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
# Search for libraries and headers in target directories
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY)
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)
Build:
mkdir build && cd build
cmake .. -DCMAKE_TOOLCHAIN_FILE=../arm-toolchain.cmake
make
Or using environment variables:
export CC=arm-linux-gnueabihf-gcc
export CXX=arm-linux-gnueabihf-g++
cmake ..
make
Meson
Cross file:
# arm-cross.txt
[binaries]
c = 'arm-linux-gnueabihf-gcc'
cpp = 'arm-linux-gnueabihf-g++'
ar = 'arm-linux-gnueabihf-ar'
strip = 'arm-linux-gnueabihf-strip'
pkgconfig = 'arm-linux-gnueabihf-pkg-config'
[host_machine]
system = 'linux'
cpu_family = 'arm'
cpu = 'cortex-a8'
endian = 'little'
[properties]
sys_root = '/usr/arm-linux-gnueabihf'
Build:
meson setup build --cross-file arm-cross.txt
ninja -C build
Root Filesystem Creation
Using Buildroot
# Configure
make menuconfig
# Filesystem images -> ext2/3/4 root filesystem
# System configuration -> System hostname, root password
# Target packages -> Select packages
# Build
make
# Output
ls output/images/
# rootfs.ext4 zImage *.dtb
Using Yocto/OpenEmbedded
# Clone Poky
git clone -b kirkstone git://git.yoctoproject.org/poky
cd poky
# Initialize build
source oe-init-build-env
# Edit conf/local.conf
# MACHINE = "beaglebone-yocto"
# Build minimal image
bitbake core-image-minimal
# Output in tmp/deploy/images/beaglebone-yocto/
Manual Root Filesystem
#!/bin/bash
# create-rootfs.sh
ROOTFS=/tmp/arm-rootfs
TOOLCHAIN=arm-linux-gnueabihf
# Create directory structure
mkdir -p $ROOTFS/{bin,sbin,etc,proc,sys,dev,lib,usr/{bin,sbin,lib},tmp,var,home,root}
# Copy libraries from toolchain
SYSROOT=$(${TOOLCHAIN}-gcc -print-sysroot)
cp -a $SYSROOT/lib/* $ROOTFS/lib/
cp -a $SYSROOT/usr/lib/* $ROOTFS/usr/lib/
# Install busybox (provides basic utilities)
git clone https://git.busybox.net/busybox
cd busybox
make ARCH=arm CROSS_COMPILE=$TOOLCHAIN- defconfig
make ARCH=arm CROSS_COMPILE=$TOOLCHAIN- -j$(nproc)
make ARCH=arm CROSS_COMPILE=$TOOLCHAIN- \
CONFIG_PREFIX=$ROOTFS install
cd ..
# Create device nodes
sudo mknod -m 666 $ROOTFS/dev/null c 1 3
sudo mknod -m 666 $ROOTFS/dev/console c 5 1
sudo mknod -m 666 $ROOTFS/dev/tty c 5 0
# Create /etc/inittab
cat > $ROOTFS/etc/inittab << 'EOF'
::sysinit:/etc/init.d/rcS
::respawn:/sbin/getty 115200 console
::shutdown:/bin/umount -a -r
::restart:/sbin/init
EOF
# Create init script
mkdir -p $ROOTFS/etc/init.d
cat > $ROOTFS/etc/init.d/rcS << 'EOF'
#!/bin/sh
mount -t proc none /proc
mount -t sysfs none /sys
mount -t tmpfs none /tmp
echo "Boot complete"
EOF
chmod +x $ROOTFS/etc/init.d/rcS
# Create filesystem image
dd if=/dev/zero of=rootfs.ext4 bs=1M count=512
mkfs.ext4 -F rootfs.ext4
mkdir -p /mnt/rootfs
sudo mount rootfs.ext4 /mnt/rootfs
sudo cp -a $ROOTFS/* /mnt/rootfs/
sudo umount /mnt/rootfs
echo "Root filesystem created: rootfs.ext4"
Debugging Cross-Compiled Code
Remote GDB Debugging
On target (ARM device):
# Install gdbserver (if not already present)
# Run application under gdbserver
gdbserver :1234 ./myapp arg1 arg2
On host (development PC):
# Use cross-gdb
arm-linux-gnueabihf-gdb ./myapp
# In GDB
(gdb) target remote target-ip:1234
(gdb) break main
(gdb) continue
(gdb) step
(gdb) print variable
(gdb) backtrace
GDB script for convenience:
# .gdbinit
target remote 192.168.1.100:1234
break main
QEMU User Mode
Run ARM binaries on x86 using QEMU:
# Install QEMU user mode
sudo apt-get install qemu-user-static
# Run ARM binary
qemu-arm-static -L /usr/arm-linux-gnueabihf ./hello
# With GDB
qemu-arm-static -L /usr/arm-linux-gnueabihf -g 1234 ./hello
# In another terminal
arm-linux-gnueabihf-gdb ./hello
(gdb) target remote :1234
QEMU System Mode
Emulate entire ARM system:
# Install QEMU system
sudo apt-get install qemu-system-arm
# Run with kernel and rootfs
qemu-system-arm \
-M vexpress-a9 \
-kernel zImage \
-dtb vexpress-v2p-ca9.dtb \
-drive file=rootfs.ext4,if=sd,format=raw \
-append "console=ttyAMA0 root=/dev/mmcblk0 rootwait" \
-serial stdio \
-net nic -net user
Analyzing Binaries
# Check architecture
file myapp
arm-linux-gnueabihf-readelf -h myapp
# List symbols
arm-linux-gnueabihf-nm myapp
# Disassemble
arm-linux-gnueabihf-objdump -d myapp
# Check shared library dependencies
arm-linux-gnueabihf-readelf -d myapp | grep NEEDED
# Strings in binary
arm-linux-gnueabihf-strings myapp
# Size information
arm-linux-gnueabihf-size myapp
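The architecture information that file and readelf report comes straight from the ELF header; a hedged Python sketch that reads it directly (the e_machine numbers are from the ELF specification):

```python
import struct

# e_machine values from the ELF specification
EM_NAMES = {3: "x86", 8: "MIPS", 40: "ARM", 62: "x86-64",
            183: "AArch64", 243: "RISC-V"}

def elf_info(path):
    """Return (bits, machine name) for an ELF binary."""
    with open(path, "rb") as f:
        header = f.read(20)
    if header[:4] != b"\x7fELF":
        raise ValueError("not an ELF file")
    bits = 64 if header[4] == 2 else 32               # EI_CLASS
    fmt = "<H" if header[5] == 1 else ">H"            # EI_DATA (byte order)
    (machine,) = struct.unpack_from(fmt, header, 18)  # e_machine at offset 18
    return bits, EM_NAMES.get(machine, f"machine {machine}")
```

Running elf_info("myapp") on an arm-linux-gnueabihf build should report (32, "ARM"); if it reports (64, "x86-64"), the cross compiler was not used.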
Common Architectures
ARM (32-bit)
# Soft-float (no FPU)
CROSS_COMPILE=arm-linux-gnueabi-
ARCH=arm
# Hard-float (with FPU)
CROSS_COMPILE=arm-linux-gnueabihf-
ARCH=arm
# Kernel config
make ARCH=arm multi_v7_defconfig
ARM64 (AArch64)
CROSS_COMPILE=aarch64-linux-gnu-
ARCH=arm64
# Kernel config
make ARCH=arm64 defconfig
MIPS
# Big-endian
CROSS_COMPILE=mips-linux-gnu-
ARCH=mips
# Little-endian
CROSS_COMPILE=mipsel-linux-gnu-
ARCH=mips
# Kernel config
make ARCH=mips malta_defconfig
RISC-V
# 64-bit
CROSS_COMPILE=riscv64-linux-gnu-
ARCH=riscv
# 32-bit
CROSS_COMPILE=riscv32-linux-gnu-
ARCH=riscv
# Kernel config
make ARCH=riscv defconfig
PowerPC
CROSS_COMPILE=powerpc-linux-gnu-
ARCH=powerpc
# Kernel config
make ARCH=powerpc pmac32_defconfig
Troubleshooting
Common Issues
Issue: "No such file or directory" for header files
# Problem: Headers not found
arm-linux-gnueabihf-gcc test.c
# test.c:1:10: fatal error: curl/curl.h: No such file or directory
# Solution: Install cross-compiled development package
sudo apt-get install libcurl4-openssl-dev:armhf
# Or specify include path
arm-linux-gnueabihf-gcc test.c \
-I/usr/arm-linux-gnueabihf/include
Issue: "cannot find -lxxx" linker errors
# Problem: Library not found
# /usr/bin/ld: cannot find -lssl
# Solution: Install library for target architecture
sudo apt-get install libssl-dev:armhf
# Or specify library path
arm-linux-gnueabihf-gcc test.c -lssl \
-L/usr/arm-linux-gnueabihf/lib
Issue: Binary runs on host but not target
# Check architecture
file myapp
# If says x86_64 instead of ARM, CROSS_COMPILE wasn't set
# Verify you're using cross compiler
which ${CROSS_COMPILE}gcc
# Check if it's stripped of debug info
${CROSS_COMPILE}readelf -S myapp | grep debug
Issue: "Exec format error" on target
# Problem: Wrong architecture or ABI mismatch
# Check target's actual architecture
ssh target 'uname -m' # armv7l, aarch64, etc.
# Check binary architecture
file myapp
# For ARM: Check float ABI
${CROSS_COMPILE}readelf -A myapp | grep ABI
# Must match target's ABI (soft-float vs hard-float)
Issue: Shared library not found on target
# Error on target
./myapp: error while loading shared libraries: libfoo.so.1
# Solution 1: Copy library to target
scp /usr/arm-linux-gnueabihf/lib/libfoo.so.* target:/lib/
# Solution 2: Static linking
arm-linux-gnueabihf-gcc test.c -o myapp -static
# Solution 3: Use LD_LIBRARY_PATH on target
export LD_LIBRARY_PATH=/path/to/libs:$LD_LIBRARY_PATH
Debugging Tips
# Verbose compiler output
arm-linux-gnueabihf-gcc -v test.c
# Show search paths
arm-linux-gnueabihf-gcc -print-search-dirs
# Show sysroot
arm-linux-gnueabihf-gcc -print-sysroot
# Preprocessor output only
arm-linux-gnueabihf-gcc -E test.c
# Show include paths
echo | arm-linux-gnueabihf-gcc -v -E -
# Test if toolchain works
arm-linux-gnueabihf-gcc -v
Best Practices
1. Use Consistent Toolchain
# Bad: Mixing toolchains
gcc myapp.c # Native compiler!
arm-linux-gnueabihf-gcc mylib.c
# Good: Use CROSS_COMPILE consistently
export CROSS_COMPILE=arm-linux-gnueabihf-
${CROSS_COMPILE}gcc myapp.c mylib.c
2. Separate Build Directories
# Keep source clean
mkdir -p build/arm build/x86
# ARM build
make O=build/arm ARCH=arm CROSS_COMPILE=arm-linux-gnueabihf-
# x86 build
make O=build/x86
# Clean specific build
rm -rf build/arm
3. Use Build Scripts
#!/bin/bash
# build-cross.sh
set -e # Exit on error
# Configuration
export ARCH=arm
export CROSS_COMPILE=arm-linux-gnueabihf-
export INSTALL_PATH=/opt/target-root
# Build
echo "Building for $ARCH..."
make clean
make -j$(nproc)
make install DESTDIR=$INSTALL_PATH
echo "Build complete: $INSTALL_PATH"
4. Maintain Sysroot
# Organized sysroot
/opt/arm-sysroot/
├── usr/
│ ├── include/ # Headers
│ └── lib/ # Libraries
├── lib/ # System libraries
└── etc/ # Configuration files
# Set PKG_CONFIG for libraries
export PKG_CONFIG_PATH=/opt/arm-sysroot/usr/lib/pkgconfig
export PKG_CONFIG_SYSROOT_DIR=/opt/arm-sysroot
5. Version Control Binaries
# Tag releases
git tag -a v1.0-arm -m "ARM release v1.0"
# Separate binary artifacts
artifacts/
├── v1.0/
│ ├── arm/
│ │ ├── myapp
│ │ └── myapp.debug
│ ├── arm64/
│ └── x86_64/
6. Automate Testing
#!/bin/bash
# test-cross.sh
# Build
./build-cross.sh
# Copy to target
scp build/myapp target:/tmp/
# Run on target
ssh target "/tmp/myapp --test"
# Check exit code
if [ $? -eq 0 ]; then
echo "Tests passed"
else
echo "Tests failed"
exit 1
fi
7. Document Dependencies
# dependencies.txt
Toolchain: gcc-arm-linux-gnueabihf-9.3
Libraries:
- libssl-dev:armhf (>= 1.1.1)
- libcurl4-openssl-dev:armhf (>= 7.68.0)
- zlib1g-dev:armhf
Kernel: 5.10 or later
Bootloader: U-Boot 2021.01
8. Optimize for Target
# Compiler optimizations
CFLAGS="-O2 -march=armv7-a -mtune=cortex-a8 -mfpu=neon"
# Size optimization
CFLAGS="-Os -ffunction-sections -fdata-sections"
LDFLAGS="-Wl,--gc-sections"
# Strip debug info for production
${CROSS_COMPILE}strip --strip-all myapp
Summary
Cross compilation is essential for embedded Linux development:
Key Steps:
- Install or build a cross-compilation toolchain
- Set CROSS_COMPILE and ARCH environment variables
- Use --sysroot or install target libraries on the host
- Build with the cross compiler instead of the native compiler
- Test on the target device or a QEMU emulator
Essential Variables:
- CROSS_COMPILE: Toolchain prefix (e.g., arm-linux-gnueabihf-)
- ARCH: Target architecture (e.g., arm, arm64, mips)
- SYSROOT: Target root filesystem path on the host
Common Workflows:
- Kernel: make ARCH=arm CROSS_COMPILE=arm-linux-gnueabihf-
- Autotools: ./configure --host=arm-linux-gnueabihf
- CMake: cmake -DCMAKE_TOOLCHAIN_FILE=arm-toolchain.cmake
- Makefile: make CROSS_COMPILE=arm-linux-gnueabihf-
cfg80211 and mac80211
Linux wireless subsystem frameworks for 802.11 (WiFi) device drivers and configuration.
Table of Contents
- Overview
- Architecture
- cfg80211
- mac80211
- Driver Development
- nl80211
- Regulatory Framework
- Power Management
- Scanning
- Connection Management
- Mesh Networking
- Debugging
Overview
The Linux wireless stack consists of two main components:
- cfg80211: Configuration API and regulatory database for 802.11 devices
- mac80211: Generic IEEE 802.11 MAC layer implementation
Why Two Layers?
┌─────────────────────────────────────┐
│ User Space (iw, wpa_supplicant) │
└─────────────────────────────────────┘
│ nl80211
┌─────────────────────────────────────┐
│ cfg80211 │ ← Configuration & regulatory
│ (wireless configuration API) │
└─────────────────────────────────────┘
│
┌─────────────────────────────────────┐
│ mac80211 │ ← MAC layer (optional)
│ (generic MAC implementation) │
└─────────────────────────────────────┘
│
┌─────────────────────────────────────┐
│ WiFi Device Driver │ ← Hardware-specific
│ (ath9k, iwlwifi, rtl8xxxu, etc.) │
└─────────────────────────────────────┘
│
┌─────────────────────────────────────┐
│ Hardware (WiFi Chip) │
└─────────────────────────────────────┘
cfg80211 is mandatory for all wireless drivers. It provides:
- Configuration interface via nl80211
- Regulatory domain management
- Scanning coordination
- Authentication/association state machine
mac80211 is optional and provides a generic MAC layer implementation for devices that only implement hardware-specific functions (PHY layer). Drivers can choose to:
- Use mac80211 (most SoftMAC drivers: ath9k, iwlwifi, rtl8xxxu)
- Implement their own MAC (FullMAC drivers: brcmfmac, mwifiex)
Architecture
Layer Responsibilities
User Space
│
├─ iw: Configuration tool
├─ wpa_supplicant: WPA/WPA2 authentication
└─ hostapd: Access Point daemon
│
▼ nl80211 (netlink)
│
cfg80211
│
├─ Configuration API
├─ Regulatory database
├─ Scan results management
├─ Connection tracking
└─ nl80211 ↔ cfg80211_ops translation
│
▼ cfg80211_ops
│
mac80211 (optional)
│
├─ Beacon handling
├─ Power save
├─ Aggregation (A-MPDU/A-MSDU)
├─ Rate control
├─ TX/RX queuing
└─ Frame filtering
│
▼ ieee80211_ops
│
Driver (hardware-specific)
│
├─ Channel switching
├─ TX/RX DMA
├─ Interrupt handling
└─ Register access
│
▼
Hardware
Data Flow
TX Path:
Application
↓
Socket/Network Stack
↓
cfg80211 (for management frames)
↓
mac80211 (encryption, aggregation, queuing)
↓
Driver (DMA, hardware TX)
↓
Hardware
RX Path:
Hardware
↓
Driver (interrupt, DMA)
↓
mac80211 (decryption, defragmentation)
↓
cfg80211 (scan results, regulatory info)
↓
Network Stack
↓
Application
cfg80211
Core Concepts
cfg80211 is the configuration API for 802.11 devices. It abstracts hardware differences and provides a unified interface.
Key Data Structures
#include <net/cfg80211.h>
/* Wireless device (wiphy) - represents physical device */
struct wiphy {
int n_addresses;
struct mac_address *addresses;
/* Supported bands */
struct ieee80211_supported_band *bands[NUM_NL80211_BANDS];
/* Regulatory domain */
const struct ieee80211_regdomain *regd;
/* Driver callbacks */
const struct cfg80211_ops *ops;
/* Flags */
u32 flags;
/* Interface modes supported */
u16 interface_modes;
/* Cipher suites */
const u32 *cipher_suites;
int n_cipher_suites;
/* Maximum scan SSIDs */
u8 max_scan_ssids;
/* Maximum scheduled scan SSIDs */
u8 max_sched_scan_ssids;
/* Private driver data */
void *priv;
};
/* Wireless interface (wdev) - represents virtual interface */
struct wireless_dev {
struct wiphy *wiphy;
enum nl80211_iftype iftype;
struct net_device *netdev;
/* Current BSS */
struct cfg80211_bss *current_bss;
/* Connection parameters */
u8 ssid[IEEE80211_MAX_SSID_LEN];
u8 ssid_len;
/* Wireless extensions compatibility */
struct cfg80211_internal_bss *authtry_bsses[4];
struct cfg80211_internal_bss *auth_bsses[4];
struct cfg80211_internal_bss *assoc_bsses[4];
};
/* BSS information */
struct cfg80211_bss {
struct ieee80211_channel *channel;
u8 bssid[ETH_ALEN];
u64 tsf;
u16 beacon_interval;
u16 capability;
const u8 *ies;
size_t ies_len;
s32 signal;
u64 parent_tsf;
};
cfg80211_ops - Driver Callbacks
struct cfg80211_ops {
/* Interface management */
int (*add_virtual_intf)(struct wiphy *wiphy,
const char *name,
enum nl80211_iftype type,
struct vif_params *params);
int (*del_virtual_intf)(struct wiphy *wiphy,
struct wireless_dev *wdev);
int (*change_virtual_intf)(struct wiphy *wiphy,
struct net_device *dev,
enum nl80211_iftype type,
struct vif_params *params);
/* Scanning */
int (*scan)(struct wiphy *wiphy,
struct cfg80211_scan_request *request);
/* Connection */
int (*connect)(struct wiphy *wiphy,
struct net_device *dev,
struct cfg80211_connect_params *sme);
int (*disconnect)(struct wiphy *wiphy,
struct net_device *dev,
u16 reason_code);
/* Authentication & Association */
int (*auth)(struct wiphy *wiphy,
struct net_device *dev,
struct cfg80211_auth_request *req);
int (*assoc)(struct wiphy *wiphy,
struct net_device *dev,
struct cfg80211_assoc_request *req);
int (*deauth)(struct wiphy *wiphy,
struct net_device *dev,
struct cfg80211_deauth_request *req);
int (*disassoc)(struct wiphy *wiphy,
struct net_device *dev,
struct cfg80211_disassoc_request *req);
/* Configuration */
int (*set_channel)(struct wiphy *wiphy,
struct cfg80211_chan_def *chandef);
int (*set_txq_params)(struct wiphy *wiphy,
struct net_device *dev,
struct ieee80211_txq_params *params);
int (*set_tx_power)(struct wiphy *wiphy,
struct wireless_dev *wdev,
enum nl80211_tx_power_setting type,
int mbm);
int (*get_tx_power)(struct wiphy *wiphy,
struct wireless_dev *wdev,
int *dbm);
/* AP mode */
int (*start_ap)(struct wiphy *wiphy,
struct net_device *dev,
struct cfg80211_ap_settings *settings);
int (*stop_ap)(struct wiphy *wiphy,
struct net_device *dev);
/* Station management */
int (*add_station)(struct wiphy *wiphy,
struct net_device *dev,
const u8 *mac,
struct station_parameters *params);
int (*del_station)(struct wiphy *wiphy,
struct net_device *dev,
struct station_del_parameters *params);
int (*change_station)(struct wiphy *wiphy,
struct net_device *dev,
const u8 *mac,
struct station_parameters *params);
int (*get_station)(struct wiphy *wiphy,
struct net_device *dev,
const u8 *mac,
struct station_info *sinfo);
/* Power management */
int (*set_power_mgmt)(struct wiphy *wiphy,
struct net_device *dev,
bool enabled,
int timeout);
/* Regulatory */
void (*reg_notifier)(struct wiphy *wiphy,
struct regulatory_request *request);
};
Registering a Wiphy
#include <net/cfg80211.h>
static const struct cfg80211_ops my_cfg_ops = {
.scan = my_scan,
.connect = my_connect,
.disconnect = my_disconnect,
/* ... other callbacks ... */
};
static int my_probe(struct pci_dev *pdev, const struct pci_device_id *id)
{
struct wiphy *wiphy;
struct my_priv *priv;
int ret;
/* Allocate wiphy with private data */
wiphy = wiphy_new(&my_cfg_ops, sizeof(*priv));
if (!wiphy)
return -ENOMEM;
priv = wiphy_priv(wiphy);
/* Set wiphy parameters */
wiphy->interface_modes = BIT(NL80211_IFTYPE_STATION) |
BIT(NL80211_IFTYPE_AP);
wiphy->max_scan_ssids = 4;
wiphy->max_scan_ie_len = 256;
/* Set supported bands */
wiphy->bands[NL80211_BAND_2GHZ] = &my_band_2ghz;
wiphy->bands[NL80211_BAND_5GHZ] = &my_band_5ghz;
/* Set supported cipher suites */
wiphy->cipher_suites = my_cipher_suites;
wiphy->n_cipher_suites = ARRAY_SIZE(my_cipher_suites);
/* Set regulatory domain */
wiphy->regulatory_flags = REGULATORY_STRICT_REG;
/* Register wiphy */
ret = wiphy_register(wiphy);
if (ret) {
wiphy_free(wiphy);
return ret;
}
return 0;
}
static void my_remove(struct pci_dev *pdev)
{
struct wiphy *wiphy = pci_get_drvdata(pdev);
wiphy_unregister(wiphy);
wiphy_free(wiphy);
}
Band and Channel Definition
/* 2.4 GHz band channels */
static struct ieee80211_channel my_2ghz_channels[] = {
{ .band = NL80211_BAND_2GHZ, .center_freq = 2412, .hw_value = 1 },
{ .band = NL80211_BAND_2GHZ, .center_freq = 2417, .hw_value = 2 },
{ .band = NL80211_BAND_2GHZ, .center_freq = 2422, .hw_value = 3 },
/* ... channels 4-13 ... */
};
/* Supported rates for 2.4 GHz */
static struct ieee80211_rate my_2ghz_rates[] = {
{ .bitrate = 10 }, /* 1 Mbps */
{ .bitrate = 20 }, /* 2 Mbps */
{ .bitrate = 55 }, /* 5.5 Mbps */
{ .bitrate = 110 }, /* 11 Mbps */
{ .bitrate = 60 }, /* 6 Mbps */
{ .bitrate = 90 }, /* 9 Mbps */
{ .bitrate = 120 }, /* 12 Mbps */
/* ... more rates ... */
};
/* 2.4 GHz band definition */
static struct ieee80211_supported_band my_band_2ghz = {
.channels = my_2ghz_channels,
.n_channels = ARRAY_SIZE(my_2ghz_channels),
.bitrates = my_2ghz_rates,
.n_bitrates = ARRAY_SIZE(my_2ghz_rates),
.ht_cap = {
.cap = IEEE80211_HT_CAP_SGI_20 |
IEEE80211_HT_CAP_SGI_40 |
IEEE80211_HT_CAP_SUP_WIDTH_20_40,
.ht_supported = true,
},
};
mac80211
Overview
mac80211 is a framework for SoftMAC 802.11 drivers. It implements the MAC layer so drivers only need to implement hardware-specific operations.
Key Features
- Frame handling: Beacon, probe, authentication, association
- Encryption: WEP, TKIP, CCMP (AES)
- Power save: PS-Poll, U-APSD
- Aggregation: A-MPDU, A-MSDU
- Rate control: Minstrel, Minstrel HT
- Quality of Service: WMM/802.11e
- Block ACK: Aggregation acknowledgment
Core Data Structures
#include <net/mac80211.h>
/* Hardware structure */
struct ieee80211_hw {
struct ieee80211_conf conf;
struct wiphy *wiphy;
const char *rate_control_algorithm;
void *priv;
unsigned long flags;
/* Queues */
u16 queues;
u16 max_listen_interval;
s8 max_signal;
/* TX aggregation */
u8 max_rx_aggregation_subframes;
u8 max_tx_aggregation_subframes;
/* Offload capabilities */
u32 offchannel_tx_hw_queue;
netdev_features_t netdev_features;
};
/* Virtual interface (VIF) */
struct ieee80211_vif {
enum nl80211_iftype type;
struct ieee80211_bss_conf bss_conf;
u8 addr[ETH_ALEN];
bool p2p;
/* Driver private data */
u8 drv_priv[0] __aligned(sizeof(void *));
};
/* BSS configuration */
struct ieee80211_bss_conf {
u8 bssid[ETH_ALEN];
bool assoc;
u16 aid;
bool use_cts_prot;
bool use_short_preamble;
bool use_short_slot;
bool enable_beacon;
u16 beacon_int;
u8 dtim_period;
u32 basic_rates;
u32 beacon_rate;
struct ieee80211_p2p_noa_attr p2p_noa_attr;
};
/* Station information */
struct ieee80211_sta {
u8 addr[ETH_ALEN];
u16 aid;
u16 max_amsdu_len;
struct ieee80211_sta_ht_cap ht_cap;
struct ieee80211_sta_vht_cap vht_cap;
u8 max_sp;
u8 rx_nss;
/* Driver private data */
u8 drv_priv[0] __aligned(sizeof(void *));
};
/* TX info - attached to each TX skb */
struct ieee80211_tx_info {
u32 flags;
u8 band;
struct ieee80211_tx_rate rates[IEEE80211_TX_MAX_RATES];
union {
struct {
struct ieee80211_vif *vif;
struct ieee80211_key_conf *hw_key;
} control;
struct {
u64 cookie;
} ack;
struct {
struct ieee80211_tx_rate rates[IEEE80211_TX_MAX_RATES];
u8 ack_signal;
} status;
};
};
/* RX status - filled by driver */
struct ieee80211_rx_status {
u64 mactime;
u32 device_timestamp;
u16 flag;
u16 freq;
u8 rate_idx;
u8 vht_nss;
u8 rx_flags;
u8 band;
u8 antenna;
s8 signal;
u8 chains;
s8 chain_signal[IEEE80211_MAX_CHAINS];
};
ieee80211_ops - Driver Operations
struct ieee80211_ops {
/* Basic operations */
int (*start)(struct ieee80211_hw *hw);
void (*stop)(struct ieee80211_hw *hw);
/* Interface handling */
int (*add_interface)(struct ieee80211_hw *hw,
struct ieee80211_vif *vif);
void (*remove_interface)(struct ieee80211_hw *hw,
struct ieee80211_vif *vif);
/* Configuration */
int (*config)(struct ieee80211_hw *hw, u32 changed);
void (*bss_info_changed)(struct ieee80211_hw *hw,
struct ieee80211_vif *vif,
struct ieee80211_bss_conf *info,
u32 changed);
/* TX/RX */
void (*tx)(struct ieee80211_hw *hw,
struct ieee80211_tx_control *control,
struct sk_buff *skb);
int (*set_key)(struct ieee80211_hw *hw,
enum set_key_cmd cmd,
struct ieee80211_vif *vif,
struct ieee80211_sta *sta,
struct ieee80211_key_conf *key);
/* Scanning */
void (*sw_scan_start)(struct ieee80211_hw *hw,
struct ieee80211_vif *vif,
const u8 *mac_addr);
void (*sw_scan_complete)(struct ieee80211_hw *hw,
struct ieee80211_vif *vif);
int (*hw_scan)(struct ieee80211_hw *hw,
struct ieee80211_vif *vif,
struct ieee80211_scan_request *req);
/* Aggregation */
int (*ampdu_action)(struct ieee80211_hw *hw,
struct ieee80211_vif *vif,
struct ieee80211_ampdu_params *params);
/* Station management */
int (*sta_add)(struct ieee80211_hw *hw,
struct ieee80211_vif *vif,
struct ieee80211_sta *sta);
int (*sta_remove)(struct ieee80211_hw *hw,
struct ieee80211_vif *vif,
struct ieee80211_sta *sta);
void (*sta_notify)(struct ieee80211_hw *hw,
struct ieee80211_vif *vif,
enum sta_notify_cmd cmd,
struct ieee80211_sta *sta);
/* Power management */
int (*set_rts_threshold)(struct ieee80211_hw *hw, u32 value);
void (*set_coverage_class)(struct ieee80211_hw *hw, s16 coverage_class);
/* Multicast filter */
void (*configure_filter)(struct ieee80211_hw *hw,
unsigned int changed_flags,
unsigned int *total_flags,
u64 multicast);
};
Registering with mac80211
static const struct ieee80211_ops my_ops = {
.start = my_start,
.stop = my_stop,
.add_interface = my_add_interface,
.remove_interface = my_remove_interface,
.config = my_config,
.bss_info_changed = my_bss_info_changed,
.tx = my_tx,
.set_key = my_set_key,
/* ... */
};
static int my_probe(struct pci_dev *pdev, const struct pci_device_id *id)
{
struct ieee80211_hw *hw;
struct my_priv *priv;
int ret;
/* Allocate hardware structure */
hw = ieee80211_alloc_hw(sizeof(*priv), &my_ops);
if (!hw)
return -ENOMEM;
priv = hw->priv;
priv->pdev = pdev;
/* Set hardware capabilities */
hw->flags = IEEE80211_HW_SIGNAL_DBM |
IEEE80211_HW_AMPDU_AGGREGATION |
IEEE80211_HW_SUPPORTS_PS |
IEEE80211_HW_MFP_CAPABLE;
hw->queues = 4; /* Number of TX queues */
hw->max_rates = 4;
hw->max_rate_tries = 7;
/* Set channel bands */
hw->wiphy->bands[NL80211_BAND_2GHZ] = &my_band_2ghz;
hw->wiphy->bands[NL80211_BAND_5GHZ] = &my_band_5ghz;
/* Set supported interface modes */
hw->wiphy->interface_modes =
BIT(NL80211_IFTYPE_STATION) |
BIT(NL80211_IFTYPE_AP) |
BIT(NL80211_IFTYPE_P2P_CLIENT) |
BIT(NL80211_IFTYPE_P2P_GO);
/* Register hardware */
ret = ieee80211_register_hw(hw);
if (ret) {
ieee80211_free_hw(hw);
return ret;
}
return 0;
}
TX Path Implementation
static void my_tx(struct ieee80211_hw *hw,
struct ieee80211_tx_control *control,
struct sk_buff *skb)
{
struct my_priv *priv = hw->priv;
struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb);
struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)skb->data;
/* Get TX rate from mac80211 rate control */
u8 rate_idx = info->control.rates[0].idx;
/* Determine hardware queue */
u8 queue = skb_get_queue_mapping(skb);
/* Add hardware-specific TX descriptor */
struct my_tx_desc *desc = (struct my_tx_desc *)skb_push(skb, sizeof(*desc));
memset(desc, 0, sizeof(*desc));
desc->rate = rate_idx;
desc->retry_limit = info->control.rates[0].count;
/* Handle encryption if needed */
if (info->control.hw_key) {
/* Hardware encryption */
desc->key_idx = info->control.hw_key->hw_key_idx;
desc->flags |= TX_FLAGS_ENCRYPT;
}
/* Submit to hardware TX queue */
spin_lock_bh(&priv->tx_lock);
if (my_tx_queue_full(priv, queue)) {
/* Queue full: stop the mac80211 queue and drop this frame */
ieee80211_stop_queue(hw, queue);
spin_unlock_bh(&priv->tx_lock);
/* mac80211 TX skbs should be freed with this helper, not dev_kfree_skb */
ieee80211_free_txskb(hw, skb);
return;
}
/* Add to DMA ring */
my_tx_add_to_ring(priv, queue, skb);
/* Kick hardware */
my_tx_kick(priv, queue);
spin_unlock_bh(&priv->tx_lock);
}
/* TX completion interrupt handler */
static void my_tx_complete(struct my_priv *priv)
{
struct ieee80211_hw *hw = priv->hw;
struct sk_buff *skb;
struct ieee80211_tx_info *info;
u8 queue;
while ((skb = my_get_completed_frame(priv, &queue))) {
info = IEEE80211_SKB_CB(skb);
/* Fill in TX status */
if (my_tx_was_successful(skb)) {
info->flags |= IEEE80211_TX_STAT_ACK;
}
/* Remove hardware TX descriptor */
skb_pull(skb, sizeof(struct my_tx_desc));
/* Report to mac80211 */
ieee80211_tx_status(hw, skb);
/* Wake queue if needed */
if (ieee80211_queue_stopped(hw, queue))
ieee80211_wake_queue(hw, queue);
}
}
RX Path Implementation
static void my_rx_tasklet(unsigned long data)
{
struct my_priv *priv = (struct my_priv *)data;
struct ieee80211_hw *hw = priv->hw;
struct sk_buff *skb;
struct ieee80211_rx_status *rx_status;
struct my_rx_desc *desc;
while ((skb = my_get_rx_frame(priv))) {
desc = (struct my_rx_desc *)skb->data;
/* Allocate rx_status */
rx_status = IEEE80211_SKB_RXCB(skb);
memset(rx_status, 0, sizeof(*rx_status));
/* Fill in RX status from hardware descriptor */
rx_status->freq = ieee80211_channel_to_frequency(
desc->channel,
NL80211_BAND_2GHZ);
rx_status->band = NL80211_BAND_2GHZ;
rx_status->signal = desc->rssi;
rx_status->rate_idx = desc->rate;
rx_status->antenna = desc->antenna;
/* Set flags */
if (desc->flags & RX_FLAG_SHORT_PREAMBLE)
rx_status->flag |= RX_FLAG_SHORTPRE;
if (desc->flags & RX_FLAG_DECRYPTED) {
rx_status->flag |= RX_FLAG_DECRYPTED;
rx_status->flag |= RX_FLAG_IV_STRIPPED;
rx_status->flag |= RX_FLAG_MMIC_STRIPPED;
}
/* Remove hardware RX descriptor */
skb_pull(skb, sizeof(*desc));
/* Pass to mac80211 */
ieee80211_rx(hw, skb);
}
}
Driver Development
FullMAC Driver Example
FullMAC drivers implement their own MAC and only use cfg80211.
#include <net/cfg80211.h>
/* FullMAC driver - implements own MAC */
static int my_fullmac_scan(struct wiphy *wiphy,
struct cfg80211_scan_request *request)
{
struct my_priv *priv = wiphy_priv(wiphy);
int i;
/* Send scan command to firmware */
for (i = 0; i < request->n_ssids; i++) {
my_fw_scan_ssid(priv,
request->ssids[i].ssid,
request->ssids[i].ssid_len);
}
for (i = 0; i < request->n_channels; i++) {
my_fw_scan_channel(priv,
request->channels[i]->center_freq);
}
my_fw_start_scan(priv);
return 0;
}
/* Firmware event: scan result */
static void my_handle_scan_result(struct my_priv *priv,
struct my_scan_result *result)
{
struct wiphy *wiphy = priv->wiphy;
struct cfg80211_bss *bss;
struct ieee80211_channel *channel;
struct cfg80211_inform_bss data = {};
channel = ieee80211_get_channel(wiphy, result->frequency);
if (!channel)
return;
/* Inform cfg80211 about BSS */
bss = cfg80211_inform_bss_data(
wiphy,
&data,
CFG80211_BSS_FTYPE_UNKNOWN,
result->bssid,
result->tsf,
result->capability,
result->beacon_interval,
result->ie,
result->ie_len,
result->signal,
GFP_KERNEL);
cfg80211_put_bss(wiphy, bss);
}
/* Firmware event: scan complete */
static void my_handle_scan_complete(struct my_priv *priv)
{
struct cfg80211_scan_info info = {
.aborted = false,
};
cfg80211_scan_done(priv->scan_request, &info);
priv->scan_request = NULL;
}
/* Connect */
static int my_fullmac_connect(struct wiphy *wiphy,
struct net_device *dev,
struct cfg80211_connect_params *sme)
{
struct my_priv *priv = wiphy_priv(wiphy);
/* Send connect command to firmware */
my_fw_connect(priv,
sme->ssid, sme->ssid_len,
sme->bssid,
sme->channel,
sme->auth_type);
return 0;
}
/* Firmware event: connected */
static void my_handle_connected(struct my_priv *priv)
{
cfg80211_connect_result(priv->dev,
priv->bssid,
NULL, 0,
NULL, 0,
WLAN_STATUS_SUCCESS,
GFP_KERNEL);
}
/* Firmware event: disconnected */
static void my_handle_disconnected(struct my_priv *priv, u16 reason)
{
cfg80211_disconnected(priv->dev, reason, NULL, 0, true, GFP_KERNEL);
}
SoftMAC Driver Example
SoftMAC drivers use mac80211 for MAC implementation.
#include <net/mac80211.h>
static int my_softmac_start(struct ieee80211_hw *hw)
{
struct my_priv *priv = hw->priv;
/* Power on hardware */
my_hw_power_on(priv);
/* Load firmware if needed */
my_load_firmware(priv);
/* Initialize hardware */
my_hw_init(priv);
/* Enable interrupts */
my_enable_interrupts(priv);
return 0;
}
static void my_softmac_stop(struct ieee80211_hw *hw)
{
struct my_priv *priv = hw->priv;
/* Disable interrupts */
my_disable_interrupts(priv);
/* Shutdown hardware */
my_hw_shutdown(priv);
/* Power off */
my_hw_power_off(priv);
}
static int my_softmac_add_interface(struct ieee80211_hw *hw,
struct ieee80211_vif *vif)
{
struct my_priv *priv = hw->priv;
/* Set MAC address */
my_hw_set_mac_address(priv, vif->addr);
/* Set interface type */
switch (vif->type) {
case NL80211_IFTYPE_STATION:
my_hw_set_mode(priv, MODE_STA);
break;
case NL80211_IFTYPE_AP:
my_hw_set_mode(priv, MODE_AP);
break;
default:
return -EOPNOTSUPP;
}
return 0;
}
static void my_softmac_bss_info_changed(struct ieee80211_hw *hw,
struct ieee80211_vif *vif,
struct ieee80211_bss_conf *info,
u32 changed)
{
struct my_priv *priv = hw->priv;
if (changed & BSS_CHANGED_BSSID) {
/* BSSID changed */
my_hw_set_bssid(priv, info->bssid);
}
if (changed & BSS_CHANGED_ASSOC) {
if (info->assoc) {
/* Associated */
my_hw_set_associated(priv, true);
my_hw_set_aid(priv, info->aid);
} else {
/* Disassociated */
my_hw_set_associated(priv, false);
}
}
if (changed & BSS_CHANGED_BEACON_INT) {
/* Beacon interval changed */
my_hw_set_beacon_interval(priv, info->beacon_int);
}
if (changed & BSS_CHANGED_ERP_CTS_PROT) {
/* CTS protection changed */
my_hw_set_cts_protection(priv, info->use_cts_prot);
}
if (changed & BSS_CHANGED_ERP_SLOT) {
/* Slot time changed */
my_hw_set_short_slot(priv, info->use_short_slot);
}
}
static int my_softmac_config(struct ieee80211_hw *hw, u32 changed)
{
struct my_priv *priv = hw->priv;
struct ieee80211_conf *conf = &hw->conf;
if (changed & IEEE80211_CONF_CHANGE_CHANNEL) {
/* Channel changed */
struct ieee80211_channel *chan = conf->chandef.chan;
my_hw_set_channel(priv, chan->center_freq);
}
if (changed & IEEE80211_CONF_CHANGE_POWER) {
/* TX power changed */
my_hw_set_tx_power(priv, conf->power_level);
}
if (changed & IEEE80211_CONF_CHANGE_IDLE) {
/* Idle state changed */
if (conf->flags & IEEE80211_CONF_IDLE)
my_hw_enter_idle(priv);
else
my_hw_exit_idle(priv);
}
return 0;
}
nl80211
nl80211 is the netlink-based configuration interface for wireless devices.
User Space Tools
# iw - nl80211 configuration utility
# List wireless devices
iw dev
# Scan for networks
iw dev wlan0 scan
# Connect to network
iw dev wlan0 connect MyNetwork
# Set channel
iw dev wlan0 set channel 6
# Set TX power
iw dev wlan0 set txpower fixed 2000 # 20 dBm
# Create AP
iw dev wlan0 set type __ap
ip link set wlan0 up
iw dev wlan0 set channel 6
# Monitor mode
iw dev wlan0 set type monitor
ip link set wlan0 up
# Station info
iw dev wlan0 station dump
# Link statistics
iw dev wlan0 link
# Survey (channel usage)
iw dev wlan0 survey dump
nl80211 in Code
#include <net/nl80211.h>
/* User space typically uses libnl */
#include <netlink/netlink.h>
#include <netlink/genl/genl.h>
#include <netlink/genl/ctrl.h>
/* Send scan request */
static int nl80211_scan(const char *ifname)
{
struct nl_sock *sk;
struct nl_msg *msg;
int ret, family_id;
sk = nl_socket_alloc();
genl_connect(sk);
family_id = genl_ctrl_resolve(sk, "nl80211");
msg = nlmsg_alloc();
genlmsg_put(msg, 0, 0, family_id, 0, 0, NL80211_CMD_TRIGGER_SCAN, 0);
nla_put_u32(msg, NL80211_ATTR_IFINDEX, if_nametoindex(ifname));
ret = nl_send_auto(sk, msg);
nlmsg_free(msg);
nl_socket_free(sk);
return ret;
}
Regulatory Framework
The regulatory framework enforces regional wireless regulations.
Regulatory Database
/* Regulatory domain definition */
static const struct ieee80211_regdomain my_regdom = {
.n_reg_rules = 2,
.alpha2 = "US",
.reg_rules = {
/* 2.4 GHz */
REG_RULE(2412-10, 2462+10, 40, 6, 20, 0),
/* 5 GHz */
REG_RULE(5180-10, 5320+10, 160, 6, 23, 0),
}
};
/* Set regulatory domain */
static void my_set_regdom(struct wiphy *wiphy)
{
regulatory_hint(wiphy, "US");
}
/* Regulatory notifier */
static void my_reg_notifier(struct wiphy *wiphy,
struct regulatory_request *request)
{
struct my_priv *priv = wiphy_priv(wiphy);
pr_info("Regulatory domain: %c%c\n",
request->alpha2[0], request->alpha2[1]);
/* Update hardware with new regulatory settings */
my_hw_update_regulatory(priv, request);
}
Country IE Handling
/* Parse country IE from beacon */
static void my_parse_country_ie(struct my_priv *priv,
const u8 *country_ie, size_t len)
{
char alpha2[2];
struct ieee80211_regdomain *rd;
if (len < 6)
return;
/* Extract country code */
alpha2[0] = country_ie[0];
alpha2[1] = country_ie[1];
/* Hint regulatory domain */
regulatory_hint(priv->wiphy, alpha2);
}
Power Management
Station Power Save
/* Enable power save */
static int my_set_power_mgmt(struct wiphy *wiphy,
struct net_device *dev,
bool enabled, int timeout)
{
struct my_priv *priv = wiphy_priv(wiphy);
if (enabled) {
my_hw_enable_power_save(priv);
my_hw_set_ps_timeout(priv, timeout);
} else {
my_hw_disable_power_save(priv);
}
return 0;
}
/* Handle beacon from AP (in power save mode) */
static void my_handle_beacon(struct my_priv *priv, struct sk_buff *skb)
{
struct ieee80211_mgmt *mgmt = (void *)skb->data;
u8 *tim_ie;
bool has_buffered;
/* Find TIM IE */
tim_ie = my_find_ie(mgmt->u.beacon.variable,
skb->len - offsetof(struct ieee80211_mgmt,
u.beacon.variable),
WLAN_EID_TIM);
if (!tim_ie)
return;
/* Check if AP has buffered frames */
has_buffered = my_check_tim(tim_ie, priv->aid);
if (has_buffered) {
/* Send PS-Poll to retrieve frames */
my_send_pspoll(priv);
}
}
AP Power Save
/* Client entered power save */
static void my_sta_ps_start(struct my_priv *priv, struct ieee80211_sta *sta)
{
/* Mark station as sleeping */
set_sta_flag(sta, WLAN_STA_PS_STA);
/* Queue frames instead of transmitting */
}
/* Client exited power save */
static void my_sta_ps_end(struct my_priv *priv, struct ieee80211_sta *sta)
{
/* Mark station as awake */
clear_sta_flag(sta, WLAN_STA_PS_STA);
/* Transmit buffered frames */
my_deliver_buffered_frames(priv, sta);
}
Scanning
Active Scan
/* Send probe request */
static void my_send_probe_req(struct my_priv *priv,
const u8 *ssid, size_t ssid_len,
u32 freq)
{
struct sk_buff *skb;
struct ieee80211_mgmt *mgmt;
u8 *pos;
skb = dev_alloc_skb(200);
if (!skb)
return;
mgmt = (struct ieee80211_mgmt *)skb_put(skb,
offsetof(struct ieee80211_mgmt, u.probe_req.variable));
/* Fill in header */
mgmt->frame_control = cpu_to_le16(IEEE80211_FTYPE_MGMT |
IEEE80211_STYPE_PROBE_REQ);
eth_broadcast_addr(mgmt->da);
memcpy(mgmt->sa, priv->mac_addr, ETH_ALEN);
eth_broadcast_addr(mgmt->bssid);
/* Add SSID IE */
pos = skb_put(skb, 2 + ssid_len);
*pos++ = WLAN_EID_SSID;
*pos++ = ssid_len;
memcpy(pos, ssid, ssid_len);
/* Add supported rates IE */
/* ... */
/* Transmit */
my_tx_mgmt_frame(priv, skb, freq);
}
Passive Scan
/* Listen for beacons on channel */
static void my_passive_scan_channel(struct my_priv *priv, u32 freq)
{
/* Switch to channel */
my_hw_set_channel(priv, freq);
/* Wait for beacons (typically 100-200ms per channel) */
msleep(100);
/* Process received beacons in RX handler */
}
Connection Management
Station Connection Flow
/* 1. Authentication */
static int my_authenticate(struct my_priv *priv,
const u8 *bssid,
enum nl80211_auth_type auth_type)
{
struct sk_buff *skb;
struct ieee80211_mgmt *mgmt;
skb = dev_alloc_skb(256);
if (!skb)
return -ENOMEM;
mgmt = (struct ieee80211_mgmt *)skb_put(skb,
offsetof(struct ieee80211_mgmt, u.auth.variable));
mgmt->frame_control = cpu_to_le16(IEEE80211_FTYPE_MGMT |
IEEE80211_STYPE_AUTH);
memcpy(mgmt->da, bssid, ETH_ALEN);
memcpy(mgmt->sa, priv->mac_addr, ETH_ALEN);
memcpy(mgmt->bssid, bssid, ETH_ALEN);
mgmt->u.auth.auth_alg = cpu_to_le16(auth_type);
mgmt->u.auth.auth_transaction = cpu_to_le16(1);
mgmt->u.auth.status_code = 0;
my_tx_mgmt_frame(priv, skb, priv->channel_freq);
return 0;
}
/* 2. Handle authentication response */
static void my_handle_auth_resp(struct my_priv *priv, struct sk_buff *skb)
{
struct ieee80211_mgmt *mgmt = (void *)skb->data;
u16 status = le16_to_cpu(mgmt->u.auth.status_code);
/* Report the received frame to cfg80211 (rx path, not tx) */
cfg80211_rx_mlme_mgmt(priv->dev, skb->data, skb->len);
if (status == WLAN_STATUS_SUCCESS) {
/* Authenticated, proceed to association */
my_associate(priv, mgmt->bssid);
}
}
/* 3. Association */
static int my_associate(struct my_priv *priv, const u8 *bssid)
{
struct sk_buff *skb;
struct ieee80211_mgmt *mgmt;
u8 *pos;
skb = dev_alloc_skb(512);
if (!skb)
return -ENOMEM;
mgmt = (struct ieee80211_mgmt *)skb_put(skb,
offsetof(struct ieee80211_mgmt, u.assoc_req.variable));
mgmt->frame_control = cpu_to_le16(IEEE80211_FTYPE_MGMT |
IEEE80211_STYPE_ASSOC_REQ);
memcpy(mgmt->da, bssid, ETH_ALEN);
memcpy(mgmt->sa, priv->mac_addr, ETH_ALEN);
memcpy(mgmt->bssid, bssid, ETH_ALEN);
mgmt->u.assoc_req.capab_info = cpu_to_le16(WLAN_CAPABILITY_ESS);
mgmt->u.assoc_req.listen_interval = cpu_to_le16(10);
pos = mgmt->u.assoc_req.variable;
/* Add SSID IE */
/* Add supported rates IE */
/* Add HT capabilities IE */
/* Add VHT capabilities IE */
/* ... */
my_tx_mgmt_frame(priv, skb, priv->channel_freq);
return 0;
}
/* 4. Handle association response */
static void my_handle_assoc_resp(struct my_priv *priv, struct sk_buff *skb)
{
struct ieee80211_mgmt *mgmt = (void *)skb->data;
u16 status = le16_to_cpu(mgmt->u.assoc_resp.status_code);
u16 aid = le16_to_cpu(mgmt->u.assoc_resp.aid);
if (status == WLAN_STATUS_SUCCESS)
priv->aid = aid & 0x3fff; /* strip the two reserved high bits */
/* Report the result either way; status conveys success or failure */
cfg80211_connect_result(priv->dev,
mgmt->bssid,
NULL, 0, NULL, 0,
status, GFP_KERNEL);
}
Mesh Networking
/* Start mesh interface */
static int my_join_mesh(struct wiphy *wiphy,
struct net_device *dev,
const struct mesh_config *conf,
const struct mesh_setup *setup)
{
struct my_priv *priv = wiphy_priv(wiphy);
/* Set mesh ID */
memcpy(priv->mesh_id, setup->mesh_id, setup->mesh_id_len);
priv->mesh_id_len = setup->mesh_id_len;
/* Enable mesh mode in hardware */
my_hw_enable_mesh(priv);
/* Start beaconing */
my_start_mesh_beaconing(priv);
return 0;
}
/* Handle mesh peering */
static void my_mesh_peer_open(struct my_priv *priv,
const u8 *peer_addr)
{
/* Send peer link open frame */
my_send_mesh_peering_frame(priv, peer_addr,
MESH_PEERING_OPEN);
}
Debugging
Enable cfg80211 Debug
# Enable cfg80211 debug messages
echo 'module cfg80211 +p' > /sys/kernel/debug/dynamic_debug/control
# Or at boot
cfg80211.debug=0xffffffff
Enable mac80211 Debug
# Enable mac80211 debug
echo 'module mac80211 +p' > /sys/kernel/debug/dynamic_debug/control
# Or at module load
modprobe mac80211 debug=0xffffffff
# Debug categories (bitfield):
# 0x00000001 - INFO
# 0x00000002 - PS (power save)
# 0x00000004 - HT (high throughput)
# 0x00000008 - TX status
Driver Debug
/* Use dev_dbg for driver messages */
dev_dbg(&pdev->dev, "Channel: %d, Freq: %d\n", channel, freq);
/* Conditional debugging */
#ifdef DEBUG
#define my_dbg(fmt, ...) pr_debug(fmt, ##__VA_ARGS__)
#else
#define my_dbg(fmt, ...) no_printk(fmt, ##__VA_ARGS__)
#endif
/* Rate control debugging */
#ifdef CONFIG_MAC80211_RC_MINSTREL_DEBUGFS
/* Rate stats available in debugfs */
/* /sys/kernel/debug/ieee80211/phyX/netdev:wlanX/stations/<MAC>/rc_stats */
#endif
Useful debugfs Entries
# List all wireless devices
ls /sys/kernel/debug/ieee80211/
# Per-PHY info
cat /sys/kernel/debug/ieee80211/phy0/hwflags
cat /sys/kernel/debug/ieee80211/phy0/queues
# Per-netdev info
ls /sys/kernel/debug/ieee80211/phy0/netdev:wlan0/
# Station info
ls /sys/kernel/debug/ieee80211/phy0/netdev:wlan0/stations/
# Rate control stats
cat /sys/kernel/debug/ieee80211/phy0/netdev:wlan0/stations/<MAC>/rc_stats
# Trigger a hardware restart (reinitializes the device, not just statistics)
echo 1 > /sys/kernel/debug/ieee80211/phy0/reset
Packet Capture
# Monitor mode for packet capture
iw dev wlan0 set type monitor
ip link set wlan0 up
iw dev wlan0 set channel 6
# Capture with tcpdump
tcpdump -i wlan0 -w capture.pcap
# Or with wireshark
wireshark -i wlan0 -k
Best Practices
Driver Development
- Use mac80211 when possible: Unless hardware has a full MAC, use mac80211
- Implement all required callbacks: Check return values
- Handle errors gracefully: Don't crash the kernel
- Test with multiple APs: Different vendors, security types
- Support monitor mode: Essential for debugging
- Implement regulatory: Country codes, power limits
- Handle race conditions: Use proper locking
- Clean up resources: On errors and removal
Performance
- Enable hardware offloads: Encryption, aggregation
- Use DMA efficiently: Minimize CPU involvement
- Implement rate control: Or use mac80211's minstrel
- Support A-MPDU/A-MSDU: For high throughput
- Optimize interrupt handling: Use NAPI if possible
- Enable power save: For battery-powered devices
Security
- Never trust user input: Validate all parameters
- Handle untrusted frames: Check lengths, types
- Implement hardware encryption: When available
- Support WPA3: Modern security standards
- Protect management frames: 802.11w (PMF)
Resources
- Kernel Documentation: Documentation/networking/mac80211.rst
- cfg80211 header: include/net/cfg80211.h
- mac80211 header: include/net/mac80211.h
- nl80211 header: include/uapi/linux/nl80211.h
- Example drivers:
- drivers/net/wireless/ath/ath9k/ - mac80211 driver
- drivers/net/wireless/broadcom/brcm80211/brcmfmac/ - FullMAC driver
- drivers/net/wireless/intel/iwlwifi/ - advanced mac80211 driver
- iw tool source: https://git.kernel.org/pub/scm/linux/kernel/git/jberg/iw.git
- Regulatory database: https://git.kernel.org/pub/scm/linux/kernel/git/sforshee/wireless-regdb.git
cfg80211 and mac80211 provide a robust framework for wireless driver development in Linux, handling much of the complex 802.11 protocol logic so drivers can focus on hardware-specific operations.
eBPF (Extended Berkeley Packet Filter)
Table of Contents
- Introduction
- Architecture
- Program Types
- eBPF Maps
- Development Tools
- Writing eBPF Programs
- Common Use Cases
- Examples
- Security and Safety
- Debugging
- Resources
Introduction
What is eBPF?
eBPF (Extended Berkeley Packet Filter) is a revolutionary Linux kernel technology that allows running sandboxed programs in kernel space without changing kernel source code or loading kernel modules. It enables dynamic extension of kernel capabilities for networking, observability, security, and performance analysis.
History
- 1992: Original BPF (Berkeley Packet Filter) created for packet filtering in BSD
- 2014: eBPF introduced in Linux kernel 3.18, extending BPF beyond networking
- 2016-Present: Rapid evolution with new program types, maps, and helper functions
Key Features
- Safe: Verifier ensures programs are safe to run in kernel space
- Efficient: JIT compilation for native performance
- Dynamic: Load/unload programs without rebooting
- Programmable: Write custom kernel extensions in C/Rust
- Event-driven: Attach to kernel/user events without overhead when not triggered
Use Cases
- Network packet filtering and manipulation
- Performance monitoring and profiling
- Security enforcement and runtime protection
- Tracing and observability
- Load balancing and service mesh
- Container networking
Architecture
eBPF Virtual Machine
eBPF programs run in a virtual machine within the kernel with:
- 11 64-bit registers (R0-R10)
- 512-byte stack
- RISC-like instruction set designed to map efficiently to native 64-bit ISAs such as x86-64 and ARM64
- Bounded loops (since kernel 5.3)
R0: Return value from functions/exit value
R1-R5: Function arguments
R6-R9: Callee-saved registers
R10: Read-only frame pointer
Core Components
1. Verifier
- Static analysis of eBPF bytecode before loading
- Ensures memory safety (no out-of-bounds access)
- Validates control flow (no infinite loops, reachable code)
- Checks register states and types
- Limits program complexity
2. JIT Compiler
- Compiles eBPF bytecode to native machine code
- Available for x86-64, ARM64, RISC-V, etc.
- Provides near-native performance
- Can be disabled (interpreter fallback)
# Enable JIT compiler
echo 1 > /proc/sys/net/core/bpf_jit_enable
# Enable JIT debug (dump compiled code)
echo 2 > /proc/sys/net/core/bpf_jit_enable
3. Helper Functions
- Kernel functions callable from eBPF programs
- Type-safe interfaces to kernel functionality
- Examples: map operations, packet manipulation, time functions
4. Maps
- Data structures for sharing data between eBPF programs and user space
- Persistent storage across program invocations
- Various types: hash, array, ring buffer, etc.
Attachment Points (Hooks)
eBPF programs attach to kernel events:
- Network: XDP, TC, socket operations, cgroups
- Tracing: kprobes, uprobes, tracepoints, USDT
- Security: LSM hooks, seccomp
- Cgroups: Device access, socket operations, sysctl
Program Types
XDP (eXpress Data Path)
Processes packets at the earliest point in the network stack (driver level).
Use Cases: DDoS mitigation, load balancing, packet filtering
Return Codes:
- XDP_DROP: Drop packet
- XDP_PASS: Pass to network stack
- XDP_TX: Bounce packet back out the same interface
- XDP_REDIRECT: Redirect to another interface
- XDP_ABORTED: Error, drop packet
Example Hook:
SEC("xdp")
int xdp_prog(struct xdp_md *ctx) {
// Access packet data
void *data_end = (void *)(long)ctx->data_end;
void *data = (void *)(long)ctx->data;
// Process packet
return XDP_PASS;
}
TC (Traffic Control)
Attaches to network queueing discipline (ingress/egress).
Use Cases: QoS, traffic shaping, packet modification
Attachment:
tc qdisc add dev eth0 clsact
tc filter add dev eth0 ingress bpf da obj prog.o sec classifier
Tracepoints
Static instrumentation points in the kernel.
Advantages: Stable ABI, defined arguments
Locations: Scheduling, system calls, network events
SEC("tracepoint/syscalls/sys_enter_execve")
int trace_execve(struct trace_event_raw_sys_enter *ctx) {
// Trace execve system call
return 0;
}
Kprobes/Kretprobes
Dynamic instrumentation of any kernel function.
Kprobe: Execute at function entry
Kretprobe: Execute at function return
SEC("kprobe/tcp_connect")
int trace_tcp_connect(struct pt_regs *ctx) {
// Hook tcp_connect function
return 0;
}
SEC("kretprobe/tcp_connect")
int trace_tcp_connect_ret(struct pt_regs *ctx) {
// Get return value
int ret = PT_REGS_RC(ctx);
return 0;
}
Uprobes/Uretprobes
Dynamic instrumentation of user-space functions.
Use Cases: Application profiling, library tracing
SEC("uprobe//usr/lib/libc.so.6:malloc")
int trace_malloc(struct pt_regs *ctx) {
size_t size = PT_REGS_PARM1(ctx);
return 0;
}
Socket Filters
Filter and process socket data.
Types:
- BPF_PROG_TYPE_SOCKET_FILTER: Classic socket filtering
- BPF_PROG_TYPE_SOCK_OPS: Socket operations monitoring
- BPF_PROG_TYPE_SK_SKB: Socket buffer redirection
- BPF_PROG_TYPE_SK_MSG: Socket message filtering
LSM (Linux Security Module)
Implement security policies using LSM hooks.
Requirements: Kernel 5.7+, BPF LSM enabled
SEC("lsm/file_open")
int BPF_PROG(file_open, struct file *file) {
// Implement access control
return 0; // Allow
}
Other Program Types
- Cgroup programs: Control resource access per cgroup
- Perf event: Attach to performance monitoring events
- Raw tracepoints: Low-overhead tracing
- BTF-enabled programs: Type information for portability
eBPF Maps
Maps are key-value data structures for storing state and communicating between eBPF programs and user space.
Map Types
BPF_MAP_TYPE_HASH
Hash table for arbitrary key-value pairs.
struct {
__uint(type, BPF_MAP_TYPE_HASH);
__uint(max_entries, 10000);
__type(key, u32);
__type(value, u64);
} my_hash_map SEC(".maps");
BPF_MAP_TYPE_ARRAY
Fixed-size array indexed by integer.
struct {
__uint(type, BPF_MAP_TYPE_ARRAY);
__uint(max_entries, 256);
__type(key, u32);
__type(value, u64);
} my_array SEC(".maps");
BPF_MAP_TYPE_PERCPU_HASH / PERCPU_ARRAY
Per-CPU variants for better performance (no locking).
struct {
__uint(type, BPF_MAP_TYPE_PERCPU_ARRAY);
__uint(max_entries, 256);
__type(key, u32);
__type(value, u64);
} percpu_stats SEC(".maps");
BPF_MAP_TYPE_RINGBUF
Ring buffer for efficient kernel-to-user data streaming (kernel 5.8+).
struct {
__uint(type, BPF_MAP_TYPE_RINGBUF);
__uint(max_entries, 256 * 1024);
} events SEC(".maps");
// Reserve and submit
struct event *e = bpf_ringbuf_reserve(&events, sizeof(*e), 0);
if (e) {
e->pid = bpf_get_current_pid_tgid() >> 32;
bpf_ringbuf_submit(e, 0);
}
BPF_MAP_TYPE_PERF_EVENT_ARRAY
Per-CPU event buffers (older than ringbuf).
struct {
__uint(type, BPF_MAP_TYPE_PERF_EVENT_ARRAY);
__uint(key_size, sizeof(u32));
__uint(value_size, sizeof(u32));
} events SEC(".maps");
BPF_MAP_TYPE_LRU_HASH
Hash table with Least Recently Used eviction.
struct {
__uint(type, BPF_MAP_TYPE_LRU_HASH);
__uint(max_entries, 10000);
__type(key, u32);
__type(value, u64);
} lru_cache SEC(".maps");
BPF_MAP_TYPE_STACK_TRACE
Store stack traces.
struct {
__uint(type, BPF_MAP_TYPE_STACK_TRACE);
__uint(max_entries, 1000);
__type(key, u32);
__type(value, u64[127]);
} stack_traces SEC(".maps");
BPF_MAP_TYPE_PROG_ARRAY
Array of eBPF programs for tail calls.
struct {
__uint(type, BPF_MAP_TYPE_PROG_ARRAY);
__uint(max_entries, 10);
__type(key, u32);
__type(value, u32);
} prog_array SEC(".maps");
// Tail call
bpf_tail_call(ctx, &prog_array, index);
Map Operations
// Lookup
value = bpf_map_lookup_elem(&my_map, &key);
// Update
bpf_map_update_elem(&my_map, &key, &value, BPF_ANY);
// Delete
bpf_map_delete_elem(&my_map, &key);
Update Flags:
- BPF_ANY: Create or update
- BPF_NOEXIST: Create only if the key doesn't exist
- BPF_EXIST: Update only if the key exists
Development Tools
BCC (BPF Compiler Collection)
Python/Lua framework for writing eBPF programs.
Pros: High-level, rapid development, many examples
Cons: Runtime compilation, LLVM dependency on target
from bcc import BPF
prog = """
int hello(void *ctx) {
bpf_trace_printk("Hello, World!\\n");
return 0;
}
"""
b = BPF(text=prog)
b.attach_kprobe(event=b.get_syscall_fnname("clone"), fn_name="hello")
libbpf
C library for loading and managing eBPF programs.
Pros: No runtime dependencies, CO-RE support, production-ready
Cons: Lower-level, more boilerplate
struct bpf_object *obj;
struct bpf_program *prog;
struct bpf_link *link;
obj = bpf_object__open_file("prog.o", NULL);
bpf_object__load(obj);
prog = bpf_object__find_program_by_name(obj, "xdp_prog");
link = bpf_program__attach(prog);
bpftool
Command-line tool for inspecting and managing eBPF programs/maps.
# List programs
bpftool prog list
# Show program details
bpftool prog show id 123
# Dump program bytecode
bpftool prog dump xlated id 123
# List maps
bpftool map list
# Dump map contents
bpftool map dump id 456
# Load program
bpftool prog load prog.o /sys/fs/bpf/myprog
# Pin map
bpftool map pin id 456 /sys/fs/bpf/mymap
eBPF for Go
import (
"github.com/cilium/ebpf"
"github.com/cilium/ebpf/link"
)
spec, err := ebpf.LoadCollectionSpec("prog.o")
coll, err := ebpf.NewCollection(spec)
defer coll.Close()
prog := coll.Programs["xdp_prog"]
l, err := link.AttachXDP(link.XDPOptions{
Program: prog,
Interface: iface.Index,
})
defer l.Close()
Other Tools
- Cilium: Container networking with eBPF
- Katran: Layer 4 load balancer (Facebook)
- Falco: Runtime security monitoring
- Pixie: Observability platform
- bpftrace: High-level tracing language
Writing eBPF Programs
Development Workflow
- Write C code with eBPF program
- Compile to eBPF bytecode using Clang/LLVM
- Load into kernel using libbpf/BCC
- Attach to hook point
- Communicate via maps
- Unload/detach when done
Basic C Program Structure
#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>
// Define map (the UAPI headers provide __u32/__u64, not u32/u64)
struct {
__uint(type, BPF_MAP_TYPE_HASH);
__uint(max_entries, 1024);
__type(key, __u32);
__type(value, __u64);
} stats SEC(".maps");
// eBPF program
SEC("xdp")
int xdp_main(struct xdp_md *ctx) {
__u32 key = 0;
__u64 *count;
count = bpf_map_lookup_elem(&stats, &key);
if (count) {
__sync_fetch_and_add(count, 1);
}
return XDP_PASS;
}
char LICENSE[] SEC("license") = "GPL";
Compilation
# Compile to eBPF bytecode
clang -O2 -g -target bpf -c prog.c -o prog.o
# With BTF (Type Information)
clang -O2 -g -target bpf -D__TARGET_ARCH_x86 \
-I/usr/include/bpf -c prog.c -o prog.o
CO-RE (Compile Once - Run Everywhere)
Problem: Kernel data structures change across versions
Solution: BTF (BPF Type Format) + CO-RE relocations
#include <vmlinux.h>
#include <bpf/bpf_core_read.h>
SEC("kprobe/tcp_connect")
int trace_connect(struct pt_regs *ctx) {
struct sock *sk = (struct sock *)PT_REGS_PARM1(ctx);
u16 family;
// CO-RE read - portable across kernel versions
BPF_CORE_READ_INTO(&family, sk, __sk_common.skc_family);
return 0;
}
Generate vmlinux.h (kernel type definitions):
bpftool btf dump file /sys/kernel/btf/vmlinux format c > vmlinux.h
User-Space Loader (libbpf)
#include <bpf/libbpf.h>
#include <bpf/bpf.h>
#include <net/if.h>
#include <stdio.h>
int main() {
struct bpf_object *obj;
struct bpf_program *prog;
int prog_fd, map_fd;
// Open and load
obj = bpf_object__open_file("prog.o", NULL);
bpf_object__load(obj);
// Get program
prog = bpf_object__find_program_by_name(obj, "xdp_main");
prog_fd = bpf_program__fd(prog);
// Get map
map_fd = bpf_object__find_map_fd_by_name(obj, "stats");
// Attach (XDP example)
int ifindex = if_nametoindex("eth0");
bpf_xdp_attach(ifindex, prog_fd, XDP_FLAGS_UPDATE_IF_NOEXIST, NULL);
// Read from map
__u32 key = 0;
__u64 value;
bpf_map_lookup_elem(map_fd, &key, &value);
printf("Count: %llu\n", (unsigned long long)value);
// Cleanup
bpf_xdp_detach(ifindex, XDP_FLAGS_UPDATE_IF_NOEXIST, NULL);
bpf_object__close(obj);
return 0;
}
Compile user-space loader:
gcc -o loader loader.c -lbpf -lelf -lz
Common Use Cases
1. Network Packet Filtering
XDP-based firewall:
- Drop malicious packets at driver level
- Block by IP, port, protocol
- DDoS mitigation
2. Load Balancing
Layer 4 load balancing:
- Distribute connections across backends
- Connection tracking
- Health checks
Examples: Katran (Facebook), Cilium
3. Observability and Tracing
System call tracing:
- Monitor file access
- Track network connections
- Profile CPU usage
Tools: BCC tools (execsnoop, opensnoop, tcpconnect)
4. Security Monitoring
Runtime security:
- Detect malicious behavior
- File integrity monitoring
- Process ancestry tracking
Tools: Falco, Tracee
5. Performance Analysis
Profiling:
- CPU flame graphs
- I/O latency
- Memory allocation tracking
6. Container Networking
CNI plugins:
- Pod networking
- Network policies
- Service mesh data plane
Examples: Cilium, Calico eBPF
7. Network Monitoring
Metrics collection:
- Packet counters
- Bandwidth monitoring
- Protocol analysis
Examples
Example 1: Packet Counter (XDP)
prog.c:
#include <linux/bpf.h>
#include <linux/if_ether.h>
#include <linux/ip.h>
#include <bpf/bpf_helpers.h>
struct {
__uint(type, BPF_MAP_TYPE_ARRAY);
__uint(max_entries, 256);
__type(key, __u32);
__type(value, __u64);
} proto_count SEC(".maps");
SEC("xdp")
int count_packets(struct xdp_md *ctx) {
void *data_end = (void *)(long)ctx->data_end;
void *data = (void *)(long)ctx->data;
struct ethhdr *eth = data;
if ((void *)(eth + 1) > data_end)
return XDP_PASS;
if (eth->h_proto != __constant_htons(ETH_P_IP))
return XDP_PASS;
struct iphdr *ip = (void *)(eth + 1);
if ((void *)(ip + 1) > data_end)
return XDP_PASS;
__u32 key = ip->protocol;
__u64 *count = bpf_map_lookup_elem(&proto_count, &key);
if (count)
__sync_fetch_and_add(count, 1);
return XDP_PASS;
}
char LICENSE[] SEC("license") = "GPL";
Compile and load:
clang -O2 -g -target bpf -c prog.c -o prog.o
ip link set dev eth0 xdp obj prog.o sec xdp
Read stats:
bpftool map dump name proto_count
Example 2: Process Execution Tracer
execsnoop.c:
#include <vmlinux.h>
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_core_read.h>
struct event {
u32 pid;
char comm[16];
};
struct {
__uint(type, BPF_MAP_TYPE_RINGBUF);
__uint(max_entries, 256 * 1024);
} events SEC(".maps");
SEC("tracepoint/syscalls/sys_enter_execve")
int trace_execve(struct trace_event_raw_sys_enter *ctx) {
struct event *e;
e = bpf_ringbuf_reserve(&events, sizeof(*e), 0);
if (!e)
return 0;
e->pid = bpf_get_current_pid_tgid() >> 32;
bpf_get_current_comm(&e->comm, sizeof(e->comm));
bpf_ringbuf_submit(e, 0);
return 0;
}
char LICENSE[] SEC("license") = "GPL";
User-space consumer:
#include <bpf/libbpf.h>
#include <bpf/bpf.h>
#include <stdio.h>
#include <linux/types.h>
struct event {
__u32 pid;
char comm[16];
};
int handle_event(void *ctx, void *data, size_t len) {
struct event *e = data;
printf("PID: %d, COMM: %s\n", e->pid, e->comm);
return 0;
}
int main() {
struct bpf_object *obj;
struct ring_buffer *rb;
int map_fd;
obj = bpf_object__open_file("execsnoop.o", NULL);
bpf_object__load(obj);
map_fd = bpf_object__find_map_fd_by_name(obj, "events");
rb = ring_buffer__new(map_fd, handle_event, NULL, NULL);
while (1) {
ring_buffer__poll(rb, 100);
}
return 0;
}
Example 3: TCP Connection Tracking
tcpconnect.c:
#include <vmlinux.h>
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_core_read.h>
#include <bpf/bpf_endian.h>
#define AF_INET 2 /* vmlinux.h carries types only, not this macro */
struct conn_event {
u32 pid;
u32 saddr;
u32 daddr;
u16 sport;
u16 dport;
};
struct {
__uint(type, BPF_MAP_TYPE_RINGBUF);
__uint(max_entries, 256 * 1024);
} events SEC(".maps");
SEC("kprobe/tcp_connect")
int trace_connect(struct pt_regs *ctx) {
struct sock *sk = (struct sock *)PT_REGS_PARM1(ctx);
struct conn_event *e;
u16 family;
BPF_CORE_READ_INTO(&family, sk, __sk_common.skc_family);
if (family != AF_INET)
return 0;
e = bpf_ringbuf_reserve(&events, sizeof(*e), 0);
if (!e)
return 0;
e->pid = bpf_get_current_pid_tgid() >> 32;
BPF_CORE_READ_INTO(&e->saddr, sk, __sk_common.skc_rcv_saddr);
BPF_CORE_READ_INTO(&e->daddr, sk, __sk_common.skc_daddr);
BPF_CORE_READ_INTO(&e->sport, sk, __sk_common.skc_num);
BPF_CORE_READ_INTO(&e->dport, sk, __sk_common.skc_dport);
e->dport = __bpf_ntohs(e->dport);
bpf_ringbuf_submit(e, 0);
return 0;
}
char LICENSE[] SEC("license") = "GPL";
Example 4: Simple LSM Hook
file_access.c (kernel 5.7+):
#include <vmlinux.h>
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_core_read.h>
SEC("lsm/file_open")
int BPF_PROG(restrict_file_open, struct file *file, int ret) {
const char *filename;
char comm[16];
char name[256];
if (ret != 0)
return ret;
bpf_get_current_comm(&comm, sizeof(comm));
filename = BPF_CORE_READ(file, f_path.dentry, d_name.name);
bpf_probe_read_kernel_str(name, sizeof(name), filename);
// Block access to /etc/shadow for specific process
if (__builtin_memcmp(name, "shadow", 6) == 0) {
bpf_printk("Blocked access to %s by %s\n", name, comm);
return -1; // EPERM
}
return 0;
}
char LICENSE[] SEC("license") = "GPL";
Security and Safety
Verifier Guarantees
The eBPF verifier ensures:
1. Memory Safety
- No out-of-bounds access
- All memory access through pointers is validated
- Null pointer checks required
2. Termination
- Bounded loops (kernel 5.3+) or loop unrolling
- No infinite loops
- Limited complexity (instruction count)
3. No Undefined Behavior
- All code paths return a value
- No unreachable code
- Register initialization checked
Verifier Checks
// ❌ BAD: Unbounded loop (pre-5.3)
for (int i = 0; i < n; i++) { }
// ✅ GOOD: Bounded loop
#pragma unroll
for (int i = 0; i < 10; i++) { }
// ✅ GOOD: Bounded with verifier check (5.3+)
for (int i = 0; i < n && i < 100; i++) { }
// ❌ BAD: Unchecked pointer
void *data = (void *)(long)ctx->data;
struct ethhdr *eth = data;
return eth->h_proto; // Verifier error!
// ✅ GOOD: Bounds check
void *data = (void *)(long)ctx->data;
void *data_end = (void *)(long)ctx->data_end;
struct ethhdr *eth = data;
if ((void *)(eth + 1) > data_end)
return XDP_DROP;
return eth->h_proto;
Required Capabilities
Loading eBPF programs requires:
- CAP_BPF (kernel 5.8+) for eBPF operations
- CAP_PERFMON for tracing programs
- CAP_NET_ADMIN for networking programs
Legacy (pre-5.8): CAP_SYS_ADMIN required
# Grant specific capabilities
setcap cap_bpf,cap_perfmon,cap_net_admin+eip ./my_program
Unprivileged eBPF
Limited eBPF for unprivileged users (disabled by default):
# Enable (use with caution)
sysctl kernel.unprivileged_bpf_disabled=0
# Disable (recommended)
sysctl kernel.unprivileged_bpf_disabled=1
Restrictions
- Limited helper functions (no arbitrary kernel memory access)
- No direct kernel pointer access
- Stack size limited to 512 bytes
- Program size limits (1M instructions)
- Map size limits (configurable)
Debugging
Common Verifier Errors
1. Invalid memory access
R0 invalid mem access 'inv'
Solution: Add bounds checks before pointer dereference
2. Unreachable instructions
unreachable insn 123
Solution: Ensure all code paths are reachable
3. Infinite loop detected
back-edge from insn 45 to 12
Solution: Add loop bounds or use #pragma unroll
4. Invalid register state
R1 !read_ok
Solution: Initialize register before use
Debugging Techniques
1. bpf_printk (Kernel Tracing)
bpf_printk("Debug: value=%d\n", value);
Read output:
cat /sys/kernel/debug/tracing/trace_pipe
# or
bpftool prog tracelog
Limitations:
- Limited format strings
- Performance overhead
- Max 3 arguments
2. bpftool Inspection
# Dump translated bytecode
bpftool prog dump xlated id 123
# Dump JIT code
bpftool prog dump jited id 123
# Show verifier log
bpftool prog load prog.o /sys/fs/bpf/prog 2>&1 | less
3. Verbose Verifier Output
// In user-space loader
LIBBPF_OPTS(bpf_object_open_opts, opts,
.kernel_log_level = 1 | 2 | 4, // Verbosity levels
);
obj = bpf_object__open_file("prog.o", &opts);
Or with bpftool:
bpftool -d prog load prog.o /sys/fs/bpf/prog
4. Map Debugging
# Dump all map entries
bpftool map dump id 123
# Update map entry
bpftool map update id 123 key 0 0 0 0 value 1 0 0 0 0 0 0 0
# Delete entry
bpftool map delete id 123 key 0 0 0 0
5. Statistics
# Enable statistics
bpftool feature probe kernel | grep stats
sysctl -w kernel.bpf_stats_enabled=1
# View program stats (run count, runtime)
bpftool prog show id 123
Performance Profiling
1. Measure Program Runtime
u64 start = bpf_ktime_get_ns();
// ... program logic ...
u64 duration = bpf_ktime_get_ns() - start;
2. Use perf with eBPF
# System-wide CPU profile; JITed eBPF programs appear as bpf_prog_* symbols
perf record -a -g -- sleep 10
perf report
Common Issues
Issue: Program rejected by verifier
- Check: Verifier log for specific error
- Solutions: Add bounds checks, limit loop iterations, reduce complexity
Issue: Map update fails
- Check: Map is full, wrong flags
- Solutions: Use LRU maps, increase size, check update flags
Issue: Helper function not found
- Check: Kernel version, program type
- Solutions: Update kernel, use available helpers for program type
Issue: BTF/CO-RE errors
- Check: BTF available (/sys/kernel/btf/vmlinux)
- Solutions: Enable CONFIG_DEBUG_INFO_BTF, use correct libbpf version
Resources
Documentation
- Official eBPF Docs: https://ebpf.io/
- Kernel Documentation: https://www.kernel.org/doc/html/latest/bpf/
- BPF and XDP Reference Guide: https://docs.cilium.io/en/latest/bpf/
- libbpf Documentation: https://libbpf.readthedocs.io/
Books
- "Learning eBPF" by Liz Rice (O'Reilly, 2023)
- "BPF Performance Tools" by Brendan Gregg (Addison-Wesley, 2019)
- "Linux Observability with BPF" by David Calavera & Lorenzo Fontana (O'Reilly, 2019)
Key Projects
- BCC: https://github.com/iovisor/bcc
- libbpf: https://github.com/libbpf/libbpf
- bpftool: https://github.com/libbpf/bpftool
- Cilium: https://github.com/cilium/cilium
- Katran: https://github.com/facebookincubator/katran
- Falco: https://github.com/falcosecurity/falco
- bpftrace: https://github.com/iovisor/bpftrace
Example Collections
- BCC Tools: https://github.com/iovisor/bcc/tree/master/tools
- libbpf-bootstrap: https://github.com/libbpf/libbpf-bootstrap
- Linux kernel samples: https://github.com/torvalds/linux/tree/master/samples/bpf
Community
- eBPF Summit: Annual conference
- eBPF Slack: https://ebpf.io/slack
- Mailing List: bpf@vger.kernel.org
- Reddit: r/ebpf
Tutorials
- Cilium eBPF Tutorial: https://github.com/cilium/ebpf-tutorial
- XDP Hands-On Tutorial: https://github.com/xdp-project/xdp-tutorial
- libbpf-bootstrap Examples: Step-by-step guides
Tools and Utilities
# Install development tools (Ubuntu/Debian)
apt install -y clang llvm libelf-dev libz-dev libbpf-dev \
linux-tools-common linux-tools-generic bpftool
# Install BCC
apt install -y bpfcc-tools python3-bpfcc
# Install bpftrace
apt install -y bpftrace
Quick Reference
Common Commands
# List all eBPF programs
bpftool prog list
# List all maps
bpftool map list
# Show program by ID
bpftool prog show id <ID>
# Dump program bytecode
bpftool prog dump xlated id <ID>
# Pin program to filesystem
bpftool prog pin id <ID> /sys/fs/bpf/<name>
# Load program from object file
bpftool prog load prog.o /sys/fs/bpf/myprog
# Attach XDP program
ip link set dev <iface> xdp obj prog.o sec xdp
# Detach XDP program
ip link set dev <iface> xdp off
# Attach TC program
tc qdisc add dev <iface> clsact
tc filter add dev <iface> ingress bpf da obj prog.o
# View trace output
cat /sys/kernel/debug/tracing/trace_pipe
Helper Function Categories
- Map operations: bpf_map_lookup_elem, bpf_map_update_elem, bpf_map_delete_elem
- Time: bpf_ktime_get_ns, bpf_ktime_get_boot_ns
- Process/Thread: bpf_get_current_pid_tgid, bpf_get_current_uid_gid, bpf_get_current_comm
- Tracing: bpf_probe_read, bpf_probe_read_kernel, bpf_probe_read_user
- Networking: bpf_skb_load_bytes, bpf_skb_store_bytes, bpf_xdp_adjust_head
- Output: bpf_printk, bpf_perf_event_output, bpf_ringbuf_submit
- Stack: bpf_get_stackid, bpf_get_stack
Kernel Version Features
- 3.18 (2014): Initial eBPF support
- 4.1 (2015): BPF maps, tail calls
- 4.4 (2016): XDP support
- 4.8 (2016): Direct packet access
- 4.18 (2018): BTF (BPF Type Format)
- 5.3 (2019): Bounded loops support
- 5.7 (2020): LSM BPF programs
- 5.8 (2020): Ring buffer, CAP_BPF
- 5.13 (2021): Kernel module function calls
- 6.0 (2022): Sleepable programs enhancements
Last Updated: 2024
Kernel Version Coverage: Linux 3.18 - 6.x
Netlink
Introduction
Netlink is a Linux kernel interface used for communication between the kernel and user-space processes, as well as between different user-space processes. It provides a flexible, extensible mechanism for transferring information and is the modern replacement for older interfaces like ioctl, /proc, and sysfs for many kernel subsystems.
What is Netlink?
Netlink is a socket-based Inter-Process Communication (IPC) mechanism that uses a special address family (AF_NETLINK). Unlike traditional sockets that communicate over networks, Netlink sockets facilitate communication between user-space and kernel-space, or even between different user-space processes.
Key Characteristics:
- Bidirectional: Both kernel and user-space can initiate communication
- Asynchronous: Supports event-driven programming model
- Multicast: Kernel can broadcast messages to multiple user-space processes
- Extensible: Easy to add new message types and protocols
- Socket-based: Uses familiar socket API (socket, bind, send, recv)
Why Use Netlink?
Netlink offers several advantages over traditional kernel-userspace communication methods:
| Method | Advantages | Disadvantages |
|---|---|---|
| ioctl | Simple, direct | Limited data transfer, not extensible, version compatibility issues |
| /proc | Human-readable | Text parsing overhead, not suitable for complex data, one-way |
| /sys | Organized, one-value-per-file | Inefficient for bulk operations, read-only limitations |
| Netlink | Flexible, extensible, bidirectional, multicast support | More complex API, steeper learning curve |
Advantages of Netlink:
- Structured Messages: Well-defined binary format with TLV (Type-Length-Value) attributes
- Extensibility: Easy to add new attributes without breaking compatibility
- Asynchronous Notifications: Kernel can push events to user-space
- Multicast Support: One-to-many communication
- Standard Socket API: Familiar programming interface
- Better Performance: No text parsing, efficient binary protocol
- Bidirectional: Both sides can initiate communication
Common Use Cases
Netlink is used extensively throughout the Linux kernel for:
- Network Configuration: rtnetlink for routing, interfaces, and addresses (used by the ip command)
- Wireless Configuration: nl80211 for WiFi management
- Netfilter/iptables: Firewall rule management
- SELinux: Security policy communication
- Audit System: Kernel audit events
- udev Events: Device hotplug notifications
- Task Statistics: Per-process statistics (taskstats)
- Connector: Generic kernel-to-user notifications
- Socket Diagnostics: Detailed socket information
Netlink Architecture Overview
graph TB
subgraph UserSpace["User Space"]
App1[Application 1]
App2[Application 2]
App3[Application 3]
Lib[libnl/pyroute2]
end
subgraph KernelSpace["Kernel Space"]
NLS[Netlink Socket Layer]
subgraph Families["Netlink Families"]
ROUTE[NETLINK_ROUTE<br/>rtnetlink]
GEN[NETLINK_GENERIC<br/>generic netlink]
NF[NETLINK_NETFILTER]
KOBJ[NETLINK_KOBJECT_UEVENT]
DIAG[NETLINK_SOCK_DIAG]
end
subgraph Subsystems["Kernel Subsystems"]
NET[Network Stack]
FW[Netfilter]
UDEV[Device Manager]
end
end
App1 -->|AF_NETLINK| NLS
App2 -->|AF_NETLINK| NLS
App3 -->|AF_NETLINK| NLS
Lib -->|AF_NETLINK| NLS
NLS --> ROUTE
NLS --> GEN
NLS --> NF
NLS --> KOBJ
NLS --> DIAG
ROUTE <--> NET
GEN <--> NET
NF <--> FW
KOBJ <--> UDEV
style UserSpace fill:#E6F3FF
style KernelSpace fill:#FFE6E6
style Families fill:#FFF9E6
style Subsystems fill:#E6FFE6
History and Evolution
- Linux 2.0 (1996): Initial netlink implementation for routing
- Linux 2.2 (1999): Expanded to support multiple protocols
- Linux 2.6 (2003+): Major expansion; generic netlink (2.6.15) and nl80211 for wireless
- Linux 3.x (2011+): Continued expansion, netlink used for most kernel-user communication
- Modern Linux: Primary interface for network configuration, replacing ioctl
Core Concepts
Netlink Socket Family
Netlink uses the AF_NETLINK address family. Creating a netlink socket is similar to creating any other socket:
int sock = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE);
Socket Type:
- SOCK_RAW: Conventionally used for netlink
- SOCK_DGRAM: Also supported; functionally equivalent to SOCK_RAW for netlink
Protocol Parameter: Specifies the netlink family/protocol:
- NETLINK_ROUTE - Routing and interface configuration
- NETLINK_GENERIC - Generic netlink
- NETLINK_NETFILTER - Netfilter subsystem
- Many others (see complete list below)
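The same socket creation works from Python on Linux, where the socket module exposes AF_NETLINK and the NETLINK_* protocol constants. A minimal sketch:

```python
import socket

# Open a NETLINK_ROUTE socket; no root privileges are needed just to create it
sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, socket.NETLINK_ROUTE)
print(sock.family == socket.AF_NETLINK)
sock.close()
```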
Netlink Protocols/Families
The Linux kernel supports numerous netlink families:
| Protocol | Value | Purpose |
|---|---|---|
| NETLINK_ROUTE | 0 | Routing and link configuration (rtnetlink) |
| NETLINK_UNUSED | 1 | Unused (legacy) |
| NETLINK_USERSOCK | 2 | Reserved for user-mode socket protocols |
| NETLINK_FIREWALL | 3 | Unused (legacy firewall) |
| NETLINK_SOCK_DIAG | 4 | Socket diagnostics |
| NETLINK_NFLOG | 5 | Netfilter logging |
| NETLINK_XFRM | 6 | IPsec |
| NETLINK_SELINUX | 7 | SELinux events |
| NETLINK_ISCSI | 8 | iSCSI |
| NETLINK_AUDIT | 9 | Kernel audit |
| NETLINK_FIB_LOOKUP | 10 | FIB lookup |
| NETLINK_CONNECTOR | 11 | Kernel connector |
| NETLINK_NETFILTER | 12 | Netfilter subsystem |
| NETLINK_IP6_FW | 13 | Unused (legacy IPv6 firewall) |
| NETLINK_DNRTMSG | 14 | DECnet routing |
| NETLINK_KOBJECT_UEVENT | 15 | Kernel object events (udev) |
| NETLINK_GENERIC | 16 | Generic netlink |
| (reserved) | 17 | Reserved (left room for NETLINK_DM / DM events) |
| NETLINK_SCSITRANSPORT | 18 | SCSI transport |
| NETLINK_ECRYPTFS | 19 | eCryptfs |
| NETLINK_RDMA | 20 | RDMA |
| NETLINK_CRYPTO | 21 | Crypto layer |
Communication Model
Netlink supports several communication patterns:
graph TB
subgraph Pattern1["Unicast (Request-Response)"]
U1[User Process] -->|Request| K1[Kernel]
K1 -->|Response| U1
end
subgraph Pattern2["Multicast (Event Broadcasting)"]
K2[Kernel] -->|Event| M1[Subscribed Process 1]
K2 -->|Event| M2[Subscribed Process 2]
K2 -->|Event| M3[Subscribed Process 3]
end
subgraph Pattern3["User-to-User"]
UU1[User Process 1] -->|Message| UU2[User Process 2]
end
style Pattern1 fill:#E6F3FF
style Pattern2 fill:#FFE6F0
style Pattern3 fill:#E6FFE6
Communication Patterns:
- Unicast (Request-Response):
  - User-space sends a request to the kernel
  - Kernel responds with data
  - Example: Getting interface information
- Multicast (Event Broadcasting):
  - Kernel broadcasts events to multiple listeners
  - User-space processes subscribe to multicast groups
  - Example: Link state changes, route updates
- User-to-User:
  - Communication between user-space processes
  - Less common, but supported
  - Example: Custom IPC using netlink
Netlink Addressing
Netlink uses a unique addressing scheme:
struct sockaddr_nl {
sa_family_t nl_family; /* AF_NETLINK */
unsigned short nl_pad; /* Zero */
__u32 nl_pid; /* Port ID (process ID or 0) */
__u32 nl_groups; /* Multicast groups mask */
};
Port ID (nl_pid):
- User-space: Typically the process PID, but can be any unique value
- Kernel: Always 0
- Autobind: Use 0 to let kernel assign a unique port ID
- Custom: Can specify any value, but must be unique
Multicast Groups (nl_groups):
- Bitmask of multicast groups to join
- Each bit represents a group (0-31)
- Used for receiving broadcast notifications
- Different for each netlink family
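The autobind behavior is easy to observe from Python: bind with nl_pid = 0 and read back the kernel-assigned port ID via getsockname(), which returns an (nl_pid, nl_groups) tuple for AF_NETLINK sockets. A Linux-only sketch:

```python
import socket

sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, socket.NETLINK_ROUTE)
# nl_pid = 0 asks the kernel to auto-assign a unique port ID;
# nl_groups = 0 joins no multicast groups
sock.bind((0, 0))
nl_pid, nl_groups = sock.getsockname()
print(nl_pid != 0)   # the kernel assigned a non-zero, unique port ID
sock.close()
```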
Port ID Assignment
flowchart LR
A[Create Socket] --> B{Specify nl_pid?}
B -->|pid = 0| C[Kernel Auto-assigns<br/>unique PID]
B -->|pid = getpid| D[Use process PID]
B -->|pid = custom| E[Use custom value<br/>must be unique]
C --> F[bind success]
D --> G{PID available?}
E --> H{Value available?}
G -->|Yes| F
G -->|No| I[EADDRINUSE error]
H -->|Yes| F
H -->|No| I
style F fill:#90EE90
style I fill:#FFB6C6
Multicast Groups
Multicast groups allow kernel to broadcast events to multiple user-space listeners:
// Example: Join RTMGRP_LINK group to receive link state changes
struct sockaddr_nl sa = {
.nl_family = AF_NETLINK,
.nl_groups = RTMGRP_LINK | RTMGRP_IPV4_ROUTE
};
bind(sock, (struct sockaddr *)&sa, sizeof(sa));
Common rtnetlink Multicast Groups:
- RTMGRP_LINK - Link state changes
- RTMGRP_NOTIFY - General notifications
- RTMGRP_NEIGH - Neighbor table updates
- RTMGRP_TC - Traffic control
- RTMGRP_IPV4_IFADDR - IPv4 address changes
- RTMGRP_IPV4_ROUTE - IPv4 routing changes
- RTMGRP_IPV6_IFADDR - IPv6 address changes
- RTMGRP_IPV6_ROUTE - IPv6 routing changes
Message Format
Netlink Message Header
Every netlink message starts with a struct nlmsghdr:
struct nlmsghdr {
__u32 nlmsg_len; /* Length of message including header */
__u16 nlmsg_type; /* Message type (protocol specific) */
__u16 nlmsg_flags; /* Additional flags */
__u32 nlmsg_seq; /* Sequence number */
__u32 nlmsg_pid; /* Sender port ID */
};
Field Details:
- nlmsg_len: Total message length in bytes, including header
- nlmsg_type: Message type/command (specific to each netlink family)
- nlmsg_flags: Control flags (request, multi-part, etc.)
- nlmsg_seq: Sequence number for matching requests/responses
- nlmsg_pid: Sender's port ID (0 for kernel, process ID for user-space)
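The header can be packed by hand; a Python sketch using struct (the RTM_GETLINK value 18 and the 16-byte ifinfomsg payload size are taken from linux/rtnetlink.h):

```python
import struct

# nlmsghdr: u32 len, u16 type, u16 flags, u32 seq, u32 pid (host byte order)
NLMSG_HDRLEN = struct.calcsize("=IHHII")  # 16 bytes

def build_nlmsg(msg_type, flags, seq, pid, payload=b""):
    # nlmsg_len counts the header plus the payload
    length = NLMSG_HDRLEN + len(payload)
    return struct.pack("=IHHII", length, msg_type, flags, seq, pid) + payload

RTM_GETLINK = 18
NLM_F_REQUEST, NLM_F_DUMP = 0x01, 0x300
# 16-byte zeroed ifinfomsg as the payload (ifi_family = AF_UNSPEC)
msg = build_nlmsg(RTM_GETLINK, NLM_F_REQUEST | NLM_F_DUMP, 1, 0, b"\x00" * 16)
print(len(msg))  # 32: 16-byte nlmsghdr + 16-byte ifinfomsg
```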
Message Types
Standard Message Types (common across all netlink families):
#define NLMSG_NOOP 0x1 /* Nothing, ignore */
#define NLMSG_ERROR 0x2 /* Error message */
#define NLMSG_DONE 0x3 /* End of multi-part message */
#define NLMSG_OVERRUN 0x4 /* Data lost */
Family-Specific Types: Each netlink family defines its own message types (>= 16)
For rtnetlink (NETLINK_ROUTE):
RTM_NEWLINK // Create/update link
RTM_DELLINK // Delete link
RTM_GETLINK // Get link info
RTM_NEWADDR // Add address
RTM_DELADDR // Delete address
RTM_GETADDR // Get address
RTM_NEWROUTE // Add route
RTM_DELROUTE // Delete route
RTM_GETROUTE // Get route
// ... many more
Message Flags
/* Request flags */
#define NLM_F_REQUEST 0x01 /* Request message */
#define NLM_F_MULTI 0x02 /* Multi-part message */
#define NLM_F_ACK 0x04 /* Request acknowledgment */
#define NLM_F_ECHO 0x08 /* Echo request */
/* Modifiers for GET requests */
#define NLM_F_ROOT 0x100 /* Return complete table */
#define NLM_F_MATCH 0x200 /* Return all matching */
#define NLM_F_ATOMIC 0x400 /* Atomic operation */
#define NLM_F_DUMP (NLM_F_ROOT | NLM_F_MATCH)
/* Modifiers for NEW requests */
#define NLM_F_REPLACE 0x100 /* Replace existing */
#define NLM_F_EXCL 0x200 /* Don't replace if exists */
#define NLM_F_CREATE 0x400 /* Create if doesn't exist */
#define NLM_F_APPEND 0x800 /* Add to end of list */
Common Flag Combinations:
- NLM_F_REQUEST | NLM_F_DUMP: Get all entries
- NLM_F_REQUEST | NLM_F_ACK: Request with acknowledgment
- NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL: Create only if it doesn't exist
- NLM_F_REQUEST | NLM_F_REPLACE: Replace existing entry
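Putting the flags to work: a Python sketch (Linux-only, no root required for reads) that sends an RTM_GETLINK request with NLM_F_REQUEST | NLM_F_DUMP and walks the multi-part reply until NLMSG_DONE:

```python
import socket, struct

NLMSG_DONE, RTM_NEWLINK, RTM_GETLINK = 3, 16, 18
NLM_F_REQUEST, NLM_F_DUMP = 0x01, 0x300

sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, socket.NETLINK_ROUTE)
sock.bind((0, 0))

# 16-byte nlmsghdr + 16-byte ifinfomsg (ifi_family = AF_UNSPEC)
payload = struct.pack("=BBHiII", 0, 0, 0, 0, 0, 0)
hdr = struct.pack("=IHHII", 16 + len(payload), RTM_GETLINK,
                  NLM_F_REQUEST | NLM_F_DUMP, 1, 0)
sock.send(hdr + payload)

types, done = [], False
while not done:
    data = sock.recv(65536)
    off = 0
    while off < len(data):
        ln, typ = struct.unpack_from("=IH", data, off)
        if ln < 16:
            break  # malformed message; stop parsing this datagram
        types.append(typ)
        if typ == NLMSG_DONE:
            done = True
        off += (ln + 3) & ~3  # advance by the 4-byte-aligned length
sock.close()

# a dump reply carries one RTM_NEWLINK per interface, then NLMSG_DONE
print(RTM_NEWLINK in types and NLMSG_DONE in types)
```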
Message Structure
graph TB
subgraph Message["Netlink Message"]
direction TB
Header[nlmsghdr<br/>16 bytes]
Payload[Message Payload]
subgraph PayloadDetail["Payload Structure"]
FamilyHdr[Family-specific Header<br/>e.g., ifinfomsg, rtmsg]
Attrs[Attributes TLV]
subgraph AttrDetail["Attributes (TLV Format)"]
Attr1[Attribute 1<br/>rtattr/nlattr]
Attr2[Attribute 2]
Attr3[Attribute 3]
AttrN[...]
end
end
end
Header --> Payload
Payload --> FamilyHdr
FamilyHdr --> Attrs
Attrs --> Attr1
Attrs --> Attr2
Attrs --> Attr3
Attrs --> AttrN
style Header fill:#FFE6E6
style FamilyHdr fill:#E6F3FF
style Attrs fill:#E6FFE6
Netlink Attributes (TLV Format)
Netlink uses Type-Length-Value (TLV) encoding for flexible, extensible message payloads:
/* Old-style attributes */
struct rtattr {
unsigned short rta_len; /* Length including header */
unsigned short rta_type; /* Attribute type */
/* Attribute data follows */
};
/* New-style attributes */
struct nlattr {
__u16 nla_len; /* Length including header */
__u16 nla_type; /* Attribute type */
/* Attribute data follows */
};
Attribute Alignment: All attributes must be aligned to 4-byte boundaries.
Macros for Attribute Manipulation:
/* Attribute length macros */
#define RTA_ALIGNTO 4
#define RTA_ALIGN(len) (((len)+RTA_ALIGNTO-1) & ~(RTA_ALIGNTO-1))
#define RTA_LENGTH(len) (RTA_ALIGN(sizeof(struct rtattr)) + (len))
#define RTA_SPACE(len) RTA_ALIGN(RTA_LENGTH(len))
/* Attribute data access */
#define RTA_DATA(rta) ((void*)(((char*)(rta)) + RTA_LENGTH(0)))
#define RTA_PAYLOAD(rta) ((int)((rta)->rta_len) - RTA_LENGTH(0))
/* Attribute iteration */
#define RTA_OK(rta,len) \
((len) >= (int)sizeof(struct rtattr) && \
(rta)->rta_len >= sizeof(struct rtattr) && \
(rta)->rta_len <= (len))
#define RTA_NEXT(rta,attrlen) \
((attrlen) -= RTA_ALIGN((rta)->rta_len), \
(struct rtattr*)(((char*)(rta)) + RTA_ALIGN((rta)->rta_len)))
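The C macros above translate directly into a few lines of Python. This sketch builds two attributes and iterates them with the same RTA_OK/RTA_NEXT logic (the type values IFLA_IFNAME = 3 and IFLA_MTU = 4 come from linux/if_link.h):

```python
import struct

RTA_HDRLEN = 4  # u16 rta_len + u16 rta_type

def rta_align(n):
    return (n + 3) & ~3

def pack_rtattr(rta_type, data):
    # rta_len counts the 4-byte header plus data, but not trailing padding
    length = RTA_HDRLEN + len(data)
    pad = b"\x00" * (rta_align(length) - length)
    return struct.pack("=HH", length, rta_type) + data + pad

def parse_rtattrs(buf):
    attrs, off = {}, 0
    while off + RTA_HDRLEN <= len(buf):           # RTA_OK
        length, rta_type = struct.unpack_from("=HH", buf, off)
        if length < RTA_HDRLEN or off + length > len(buf):
            break                                  # malformed attribute
        attrs[rta_type] = buf[off + RTA_HDRLEN : off + length]  # RTA_DATA
        off += rta_align(length)                   # RTA_NEXT
    return attrs

IFLA_IFNAME, IFLA_MTU = 3, 4
blob = (pack_rtattr(IFLA_IFNAME, b"eth0\x00") +
        pack_rtattr(IFLA_MTU, struct.pack("=I", 1500)))
attrs = parse_rtattrs(blob)
print(attrs[IFLA_IFNAME].rstrip(b"\x00").decode())  # eth0
print(struct.unpack("=I", attrs[IFLA_MTU])[0])      # 1500
```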
Nested Attributes
Attributes can contain other attributes (nesting):
/* Creating a nested attribute (sketch: add_attribute() and current_pos are
 * illustrative helpers, not kernel or libc APIs) */
struct rtattr *nest = (struct rtattr *)buffer;
nest->rta_type = IFLA_LINKINFO;
nest->rta_len = RTA_LENGTH(0);   /* placeholder; patched below */
/* Add child attributes after the nest header */
add_attribute(buffer, IFLA_INFO_KIND, "vlan", 4);
add_attribute(buffer, IFLA_INFO_DATA, &data, sizeof(data));
/* Patch the nest length to cover the header plus all children */
nest->rta_len = (char *)current_pos - (char *)nest;
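The same idea as runnable Python: a nested attribute is simply an attribute whose payload is the concatenation of its child attributes. Packing the children first avoids the length patch-up step (IFLA_LINKINFO = 18 and IFLA_INFO_KIND = 1 come from linux/if_link.h):

```python
import struct

def rta_align(n):
    return (n + 3) & ~3

def pack_rtattr(rta_type, data):
    # rta_len counts the 4-byte header plus data, but not trailing padding
    length = 4 + len(data)
    pad = b"\x00" * (rta_align(length) - length)
    return struct.pack("=HH", length, rta_type) + data + pad

IFLA_LINKINFO, IFLA_INFO_KIND = 18, 1
# the nested attribute's payload is the already-packed child attribute
nested = pack_rtattr(IFLA_LINKINFO, pack_rtattr(IFLA_INFO_KIND, b"vlan\x00"))
length, rta_type = struct.unpack_from("=HH", nested, 0)
print(rta_type == IFLA_LINKINFO and length == len(nested))  # True
```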
Message Alignment and Padding
graph LR
subgraph Msg["Message Layout (bytes)"]
H[Header<br/>0-15]
P1[Payload<br/>16-N]
Pad1[Padding<br/>0-3 bytes]
A1[Attr1 Header<br/>4 bytes]
A1D[Attr1 Data]
Pad2[Padding]
A2[Attr2 Header]
A2D[Attr2 Data]
end
style H fill:#FFE6E6
style P1 fill:#E6F3FF
style Pad1 fill:#FFFFE6
style A1 fill:#E6FFE6
style Pad2 fill:#FFFFE6
Alignment Rules:
- Messages are aligned to 4-byte boundaries (NLMSG_ALIGNTO)
- Attributes are aligned to 4-byte boundaries (RTA_ALIGNTO/NLA_ALIGNTO)
- Padding bytes should be zeroed
- Length fields include the header and data, but not padding
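A quick check of the alignment rule: the length field counts header plus data, while the parse cursor always advances by the 4-byte-aligned length.

```python
def nlmsg_align(n):
    # NLMSG_ALIGNTO and RTA_ALIGNTO are both 4
    return (n + 3) & ~3

# (raw length, aligned length): padding exists only in the cursor stride
for raw, aligned in [(16, 16), (17, 20), (9, 12)]:
    print(raw, "->", nlmsg_align(raw))
```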
rtnetlink (NETLINK_ROUTE)
rtnetlink is the most commonly used netlink family, providing network configuration capabilities used by tools like ip, ifconfig, and route.
Capabilities
rtnetlink can manage:
- Network interfaces (create, delete, configure)
- IP addresses (add, remove, query)
- Routing tables (add/delete routes)
- Neighbor tables (ARP/NDP)
- Traffic control (qdisc, classes, filters)
- Network namespaces
- Tunnels and virtual interfaces
Message Types
/* Link messages */
RTM_NEWLINK /* Create/modify link */
RTM_DELLINK /* Delete link */
RTM_GETLINK /* Get link info */
RTM_SETLINK /* Set link attributes */
/* Address messages */
RTM_NEWADDR /* Add address */
RTM_DELADDR /* Delete address */
RTM_GETADDR /* Get address */
/* Route messages */
RTM_NEWROUTE /* Add route */
RTM_DELROUTE /* Delete route */
RTM_GETROUTE /* Get route */
/* Neighbor messages */
RTM_NEWNEIGH /* Add neighbor */
RTM_DELNEIGH /* Delete neighbor */
RTM_GETNEIGH /* Get neighbor */
/* Rule messages */
RTM_NEWRULE /* Add routing rule */
RTM_DELRULE /* Delete routing rule */
RTM_GETRULE /* Get routing rule */
/* Qdisc messages */
RTM_NEWQDISC /* Add qdisc */
RTM_DELQDISC /* Delete qdisc */
RTM_GETQDISC /* Get qdisc */
/* Traffic class messages */
RTM_NEWTCLASS /* Add traffic class */
RTM_DELTCLASS /* Delete traffic class */
RTM_GETTCLASS /* Get traffic class */
/* Filter messages */
RTM_NEWTFILTER /* Add filter */
RTM_DELTFILTER /* Delete filter */
RTM_GETTFILTER /* Get filter */
Link Management
Interface Information Message (ifinfomsg):
struct ifinfomsg {
unsigned char ifi_family; /* AF_UNSPEC */
unsigned char __ifi_pad;
unsigned short ifi_type; /* Device type (ARPHRD_*) */
int ifi_index; /* Interface index */
unsigned int ifi_flags; /* Device flags (IFF_*) */
unsigned int ifi_change; /* Change mask */
};
Link Attributes:
enum {
IFLA_UNSPEC,
IFLA_ADDRESS, /* Hardware address */
IFLA_BROADCAST, /* Broadcast address */
IFLA_IFNAME, /* Interface name */
IFLA_MTU, /* MTU */
IFLA_LINK, /* Link index */
IFLA_QDISC, /* Queueing discipline */
IFLA_STATS, /* Interface statistics */
IFLA_MASTER, /* Master device index */
IFLA_OPERSTATE, /* Operating state */
IFLA_LINKMODE, /* Link mode */
IFLA_LINKINFO, /* Link type info (nested) */
IFLA_TXQLEN, /* Transmit queue length */
IFLA_MAP, /* Device mapping */
IFLA_WEIGHT, /* Weight */
// ... many more
};
Example: Getting Link Information
#include <linux/netlink.h>
#include <linux/rtnetlink.h>
#include <sys/socket.h>
#include <unistd.h>
#include <stdio.h>
#include <string.h>
int main() {
int sock;
struct {
struct nlmsghdr nlh;
struct ifinfomsg ifi;
} req;
char buf[8192];
struct iovec iov;
struct msghdr msg;
/* Create netlink socket */
sock = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE);
if (sock < 0) {
perror("socket");
return 1;
}
/* Prepare request message */
memset(&req, 0, sizeof(req));
req.nlh.nlmsg_len = NLMSG_LENGTH(sizeof(struct ifinfomsg));
req.nlh.nlmsg_type = RTM_GETLINK;
req.nlh.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
req.nlh.nlmsg_seq = 1;
req.nlh.nlmsg_pid = getpid();
req.ifi.ifi_family = AF_UNSPEC;
/* Send request */
iov.iov_base = &req;
iov.iov_len = req.nlh.nlmsg_len;
memset(&msg, 0, sizeof(msg));
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
if (sendmsg(sock, &msg, 0) < 0) {
perror("sendmsg");
close(sock);
return 1;
}
/* Receive response */
while (1) {
struct nlmsghdr *nlh;
int len;
iov.iov_base = buf;
iov.iov_len = sizeof(buf);
len = recvmsg(sock, &msg, 0);
if (len < 0) {
perror("recvmsg");
break;
}
for (nlh = (struct nlmsghdr *)buf;
NLMSG_OK(nlh, len);
nlh = NLMSG_NEXT(nlh, len)) {
if (nlh->nlmsg_type == NLMSG_DONE) {
goto done;
}
if (nlh->nlmsg_type == NLMSG_ERROR) {
fprintf(stderr, "Error in netlink response\n");
goto done;
}
if (nlh->nlmsg_type == RTM_NEWLINK) {
struct ifinfomsg *ifi = NLMSG_DATA(nlh);
struct rtattr *rta = IFLA_RTA(ifi);
int rtalen = IFLA_PAYLOAD(nlh);
printf("Interface %d: ", ifi->ifi_index);
/* Parse attributes */
while (RTA_OK(rta, rtalen)) {
if (rta->rta_type == IFLA_IFNAME) {
printf("%s ", (char *)RTA_DATA(rta));
} else if (rta->rta_type == IFLA_MTU) {
printf("MTU=%u ", *(unsigned int *)RTA_DATA(rta));
} else if (rta->rta_type == IFLA_OPERSTATE) {
unsigned char state = *(unsigned char *)RTA_DATA(rta);
/* IF_OPER_UP = 6, IF_OPER_DOWN = 2 (see linux/if.h) */
printf("State=%s ",
state == 6 ? "UP" :
state == 2 ? "DOWN" : "UNKNOWN");
}
rta = RTA_NEXT(rta, rtalen);
}
printf("\n");
}
}
}
done:
close(sock);
return 0;
}
Example: Setting Link UP/DOWN
#include <linux/netlink.h>
#include <linux/rtnetlink.h>
#include <net/if.h>
#include <sys/socket.h>
#include <string.h>
#include <stdio.h>
#include <unistd.h>
int set_link_state(const char *ifname, int up) {
int sock;
struct {
struct nlmsghdr nlh;
struct ifinfomsg ifi;
char attrbuf[512];
} req;
struct rtattr *rta;
int ifindex;
/* Get interface index */
ifindex = if_nametoindex(ifname);
if (ifindex == 0) {
perror("if_nametoindex");
return -1;
}
/* Create socket */
sock = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE);
if (sock < 0) {
perror("socket");
return -1;
}
/* Prepare request */
memset(&req, 0, sizeof(req));
req.nlh.nlmsg_len = NLMSG_LENGTH(sizeof(struct ifinfomsg));
req.nlh.nlmsg_type = RTM_NEWLINK;
req.nlh.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
req.nlh.nlmsg_seq = 1;
req.nlh.nlmsg_pid = getpid();
req.ifi.ifi_family = AF_UNSPEC;
req.ifi.ifi_index = ifindex;
req.ifi.ifi_flags = up ? IFF_UP : 0;
req.ifi.ifi_change = IFF_UP;
/* Send request */
if (send(sock, &req, req.nlh.nlmsg_len, 0) < 0) {
perror("send");
close(sock);
return -1;
}
/* Wait for acknowledgment */
char buf[4096];
int len = recv(sock, buf, sizeof(buf), 0);
struct nlmsghdr *nlh = (struct nlmsghdr *)buf;
if (nlh->nlmsg_type == NLMSG_ERROR) {
struct nlmsgerr *err = (struct nlmsgerr *)NLMSG_DATA(nlh);
if (err->error != 0) {
fprintf(stderr, "Netlink error: %s\n", strerror(-err->error));
close(sock);
return -1;
}
}
close(sock);
return 0;
}
int main(int argc, char *argv[]) {
if (argc != 3) {
fprintf(stderr, "Usage: %s <interface> <up|down>\n", argv[0]);
return 1;
}
int up = strcmp(argv[2], "up") == 0;
if (set_link_state(argv[1], up) == 0) {
printf("Successfully set %s %s\n", argv[1], up ? "UP" : "DOWN");
return 0;
}
return 1;
}
Address Management
Address Information Message (ifaddrmsg):
struct ifaddrmsg {
__u8 ifa_family; /* Address family (AF_INET/AF_INET6) */
__u8 ifa_prefixlen; /* Prefix length */
__u8 ifa_flags; /* Address flags (IFA_F_*) */
__u8 ifa_scope; /* Address scope (RT_SCOPE_*) */
__u32 ifa_index; /* Interface index */
};
Address Attributes:
enum {
IFA_UNSPEC,
IFA_ADDRESS, /* Address itself */
IFA_LOCAL, /* Local address */
IFA_LABEL, /* Interface name */
IFA_BROADCAST, /* Broadcast address */
IFA_ANYCAST, /* Anycast address */
IFA_CACHEINFO, /* Address cache info */
IFA_MULTICAST, /* Multicast address */
IFA_FLAGS, /* Extended flags */
// ...
};
Example: Adding an IP Address
#include <linux/netlink.h>
#include <linux/rtnetlink.h>
#include <net/if.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <string.h>
#include <stdio.h>
#include <unistd.h>
#include <stdlib.h> /* atoi */
int add_ipv4_address(const char *ifname, const char *ip, int prefixlen) {
int sock;
struct {
struct nlmsghdr nlh;
struct ifaddrmsg ifa;
char attrbuf[512];
} req;
struct rtattr *rta;
int ifindex;
struct in_addr addr;
/* Get interface index */
ifindex = if_nametoindex(ifname);
if (ifindex == 0) {
perror("if_nametoindex");
return -1;
}
/* Parse IP address */
if (inet_pton(AF_INET, ip, &addr) != 1) {
fprintf(stderr, "Invalid IP address\n");
return -1;
}
/* Create socket */
sock = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE);
if (sock < 0) {
perror("socket");
return -1;
}
/* Prepare request */
memset(&req, 0, sizeof(req));
req.nlh.nlmsg_len = NLMSG_LENGTH(sizeof(struct ifaddrmsg));
req.nlh.nlmsg_type = RTM_NEWADDR;
req.nlh.nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL | NLM_F_ACK;
req.nlh.nlmsg_seq = 1;
req.nlh.nlmsg_pid = getpid();
req.ifa.ifa_family = AF_INET;
req.ifa.ifa_prefixlen = prefixlen;
req.ifa.ifa_flags = IFA_F_PERMANENT;
req.ifa.ifa_scope = RT_SCOPE_UNIVERSE;
req.ifa.ifa_index = ifindex;
/* Add IFA_LOCAL attribute */
rta = (struct rtattr *)(((char *)&req) + NLMSG_ALIGN(req.nlh.nlmsg_len));
rta->rta_type = IFA_LOCAL;
rta->rta_len = RTA_LENGTH(sizeof(addr));
memcpy(RTA_DATA(rta), &addr, sizeof(addr));
req.nlh.nlmsg_len = NLMSG_ALIGN(req.nlh.nlmsg_len) + RTA_LENGTH(sizeof(addr));
/* Add IFA_ADDRESS attribute */
rta = (struct rtattr *)(((char *)&req) + NLMSG_ALIGN(req.nlh.nlmsg_len));
rta->rta_type = IFA_ADDRESS;
rta->rta_len = RTA_LENGTH(sizeof(addr));
memcpy(RTA_DATA(rta), &addr, sizeof(addr));
req.nlh.nlmsg_len = NLMSG_ALIGN(req.nlh.nlmsg_len) + RTA_LENGTH(sizeof(addr));
/* Send request */
if (send(sock, &req, req.nlh.nlmsg_len, 0) < 0) {
perror("send");
close(sock);
return -1;
}
/* Check acknowledgment */
char buf[4096];
int len = recv(sock, buf, sizeof(buf), 0);
struct nlmsghdr *nlh = (struct nlmsghdr *)buf;
if (nlh->nlmsg_type == NLMSG_ERROR) {
struct nlmsgerr *err = (struct nlmsgerr *)NLMSG_DATA(nlh);
if (err->error != 0) {
fprintf(stderr, "Failed to add address: %s\n", strerror(-err->error));
close(sock);
return -1;
}
}
close(sock);
return 0;
}
int main(int argc, char *argv[]) {
if (argc != 4) {
fprintf(stderr, "Usage: %s <interface> <ip> <prefixlen>\n", argv[0]);
fprintf(stderr, "Example: %s eth0 192.168.1.100 24\n", argv[0]);
return 1;
}
int prefixlen = atoi(argv[3]);
if (add_ipv4_address(argv[1], argv[2], prefixlen) == 0) {
printf("Successfully added %s/%d to %s\n", argv[2], prefixlen, argv[1]);
return 0;
}
return 1;
}
Route Management
Route Message (rtmsg):
struct rtmsg {
unsigned char rtm_family; /* Address family (AF_INET/AF_INET6) */
unsigned char rtm_dst_len; /* Destination prefix length */
unsigned char rtm_src_len; /* Source prefix length */
unsigned char rtm_tos; /* Type of service */
unsigned char rtm_table; /* Routing table ID */
unsigned char rtm_protocol; /* Routing protocol */
unsigned char rtm_scope; /* Route scope */
unsigned char rtm_type; /* Route type */
unsigned int rtm_flags; /* Route flags */
};
Route Attributes:
enum {
RTA_UNSPEC,
RTA_DST, /* Destination address */
RTA_SRC, /* Source address */
RTA_IIF, /* Input interface */
RTA_OIF, /* Output interface */
RTA_GATEWAY, /* Gateway address */
RTA_PRIORITY, /* Route priority/metric */
RTA_PREFSRC, /* Preferred source address */
RTA_METRICS, /* Route metrics */
RTA_MULTIPATH, /* Multipath route */
RTA_FLOW, /* Flow classification */
RTA_CACHEINFO, /* Cache information */
RTA_TABLE, /* Routing table ID (extended) */
// ... more
};
Example: Adding a Route
#include <linux/netlink.h>
#include <linux/rtnetlink.h>
#include <net/if.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <string.h>
#include <stdio.h>
#include <unistd.h>
#include <stdlib.h> /* atoi */
int add_route(const char *dest, int prefixlen, const char *gateway, const char *ifname) {
int sock;
struct {
struct nlmsghdr nlh;
struct rtmsg rtm;
char attrbuf[512];
} req;
struct rtattr *rta;
struct in_addr dst_addr, gw_addr;
int ifindex;
/* Parse addresses */
if (inet_pton(AF_INET, dest, &dst_addr) != 1) {
fprintf(stderr, "Invalid destination address\n");
return -1;
}
if (gateway && inet_pton(AF_INET, gateway, &gw_addr) != 1) {
fprintf(stderr, "Invalid gateway address\n");
return -1;
}
/* Get interface index if specified */
if (ifname) {
ifindex = if_nametoindex(ifname);
if (ifindex == 0) {
perror("if_nametoindex");
return -1;
}
}
/* Create socket */
sock = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE);
if (sock < 0) {
perror("socket");
return -1;
}
/* Prepare request */
memset(&req, 0, sizeof(req));
req.nlh.nlmsg_len = NLMSG_LENGTH(sizeof(struct rtmsg));
req.nlh.nlmsg_type = RTM_NEWROUTE;
req.nlh.nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL | NLM_F_ACK;
req.nlh.nlmsg_seq = 1;
req.nlh.nlmsg_pid = getpid();
req.rtm.rtm_family = AF_INET;
req.rtm.rtm_dst_len = prefixlen;
req.rtm.rtm_table = RT_TABLE_MAIN;
req.rtm.rtm_protocol = RTPROT_BOOT;
req.rtm.rtm_scope = RT_SCOPE_UNIVERSE;
req.rtm.rtm_type = RTN_UNICAST;
/* Add RTA_DST attribute */
rta = (struct rtattr *)(((char *)&req) + NLMSG_ALIGN(req.nlh.nlmsg_len));
rta->rta_type = RTA_DST;
rta->rta_len = RTA_LENGTH(sizeof(dst_addr));
memcpy(RTA_DATA(rta), &dst_addr, sizeof(dst_addr));
req.nlh.nlmsg_len = NLMSG_ALIGN(req.nlh.nlmsg_len) + RTA_LENGTH(sizeof(dst_addr));
/* Add RTA_GATEWAY attribute if specified */
if (gateway) {
rta = (struct rtattr *)(((char *)&req) + NLMSG_ALIGN(req.nlh.nlmsg_len));
rta->rta_type = RTA_GATEWAY;
rta->rta_len = RTA_LENGTH(sizeof(gw_addr));
memcpy(RTA_DATA(rta), &gw_addr, sizeof(gw_addr));
req.nlh.nlmsg_len = NLMSG_ALIGN(req.nlh.nlmsg_len) + RTA_LENGTH(sizeof(gw_addr));
}
/* Add RTA_OIF attribute if specified */
if (ifname) {
rta = (struct rtattr *)(((char *)&req) + NLMSG_ALIGN(req.nlh.nlmsg_len));
rta->rta_type = RTA_OIF;
rta->rta_len = RTA_LENGTH(sizeof(ifindex));
memcpy(RTA_DATA(rta), &ifindex, sizeof(ifindex));
req.nlh.nlmsg_len = NLMSG_ALIGN(req.nlh.nlmsg_len) + RTA_LENGTH(sizeof(ifindex));
}
/* Send request */
if (send(sock, &req, req.nlh.nlmsg_len, 0) < 0) {
perror("send");
close(sock);
return -1;
}
/* Check acknowledgment */
char buf[4096];
int len = recv(sock, buf, sizeof(buf), 0);
struct nlmsghdr *nlh = (struct nlmsghdr *)buf;
if (nlh->nlmsg_type == NLMSG_ERROR) {
struct nlmsgerr *err = (struct nlmsgerr *)NLMSG_DATA(nlh);
if (err->error != 0) {
fprintf(stderr, "Failed to add route: %s\n", strerror(-err->error));
close(sock);
return -1;
}
}
close(sock);
return 0;
}
int main(int argc, char *argv[]) {
if (argc < 3) {
fprintf(stderr, "Usage: %s <dest> <prefixlen> [gateway] [interface]\n", argv[0]);
fprintf(stderr, "Example: %s 192.168.2.0 24 192.168.1.1 eth0\n", argv[0]);
return 1;
}
const char *dest = argv[1];
int prefixlen = atoi(argv[2]);
const char *gateway = argc > 3 ? argv[3] : NULL;
const char *ifname = argc > 4 ? argv[4] : NULL;
if (add_route(dest, prefixlen, gateway, ifname) == 0) {
printf("Successfully added route\n");
return 0;
}
return 1;
}
Monitoring Link Changes
#include <linux/netlink.h>
#include <linux/rtnetlink.h>
#include <sys/socket.h>
#include <unistd.h>
#include <stdio.h>
#include <string.h>
void monitor_link_changes() {
int sock;
struct sockaddr_nl sa;
char buf[8192];
/* Create socket */
sock = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE);
if (sock < 0) {
perror("socket");
return;
}
/* Bind to multicast groups */
memset(&sa, 0, sizeof(sa));
sa.nl_family = AF_NETLINK;
sa.nl_groups = RTMGRP_LINK | RTMGRP_IPV4_IFADDR | RTMGRP_IPV4_ROUTE;
if (bind(sock, (struct sockaddr *)&sa, sizeof(sa)) < 0) {
perror("bind");
close(sock);
return;
}
printf("Monitoring network changes...\n");
/* Receive and process events */
while (1) {
struct nlmsghdr *nlh;
int len = recv(sock, buf, sizeof(buf), 0);
if (len < 0) {
perror("recv");
break;
}
for (nlh = (struct nlmsghdr *)buf;
NLMSG_OK(nlh, len);
nlh = NLMSG_NEXT(nlh, len)) {
if (nlh->nlmsg_type == RTM_NEWLINK || nlh->nlmsg_type == RTM_DELLINK) {
struct ifinfomsg *ifi = NLMSG_DATA(nlh);
const char *action = nlh->nlmsg_type == RTM_NEWLINK ? "NEW/UPDATE" : "DELETE";
printf("LINK %s: index=%d flags=0x%x\n",
action, ifi->ifi_index, ifi->ifi_flags);
/* Parse attributes */
struct rtattr *rta = IFLA_RTA(ifi);
int rtalen = IFLA_PAYLOAD(nlh);
while (RTA_OK(rta, rtalen)) {
if (rta->rta_type == IFLA_IFNAME) {
printf(" Interface: %s\n", (char *)RTA_DATA(rta));
}
rta = RTA_NEXT(rta, rtalen);
}
} else if (nlh->nlmsg_type == RTM_NEWADDR || nlh->nlmsg_type == RTM_DELADDR) {
struct ifaddrmsg *ifa = NLMSG_DATA(nlh);
const char *action = nlh->nlmsg_type == RTM_NEWADDR ? "NEW" : "DELETE";
printf("ADDR %s: family=%d index=%d\n",
action, ifa->ifa_family, ifa->ifa_index);
} else if (nlh->nlmsg_type == RTM_NEWROUTE || nlh->nlmsg_type == RTM_DELROUTE) {
struct rtmsg *rtm = NLMSG_DATA(nlh);
const char *action = nlh->nlmsg_type == RTM_NEWROUTE ? "NEW" : "DELETE";
printf("ROUTE %s: family=%d dst_len=%d\n",
action, rtm->rtm_family, rtm->rtm_dst_len);
}
}
}
close(sock);
}
int main() {
monitor_link_changes();
return 0;
}
Generic Netlink
Generic Netlink (NETLINK_GENERIC) is a meta-protocol that allows kernel modules to create custom netlink families without needing a dedicated netlink protocol number. It's the recommended way to add new netlink-based interfaces.
Why Generic Netlink?
Traditional Approach Problems:
- Limited number of netlink protocol numbers (0-31)
- Each subsystem needs a dedicated protocol number
- Protocol numbers are a scarce resource
Generic Netlink Solution:
- Multiplexes multiple "families" over a single protocol (NETLINK_GENERIC)
- Dynamic family registration
- Automatic command and attribute validation
- Easier to add new interfaces
Architecture
graph TB
subgraph UserSpace["User Space"]
App[Application]
end
subgraph KernelSpace["Kernel Space"]
GNL[Generic Netlink Core]
subgraph Families["Generic Netlink Families"]
NL80211[nl80211<br/>WiFi]
DEVLINK[devlink<br/>Device Config]
TEAM[team<br/>Link Aggregation]
TASKSTATS[taskstats<br/>Task Statistics]
CUSTOM[Custom Family]
end
end
App -->|NETLINK_GENERIC| GNL
GNL --> NL80211
GNL --> DEVLINK
GNL --> TEAM
GNL --> TASKSTATS
GNL --> CUSTOM
style UserSpace fill:#E6F3FF
style KernelSpace fill:#FFE6E6
style Families fill:#E6FFE6
Generic Netlink Message Structure
struct genlmsghdr {
__u8 cmd; /* Command */
__u8 version; /* Version */
__u16 reserved; /* Reserved */
};
The complete message structure:
+-------------------+
| nlmsghdr | <- Standard netlink header
+-------------------+
| genlmsghdr | <- Generic netlink header
+-------------------+
| Family Attributes | <- Family-specific data (TLV)
+-------------------+
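The layout above, packed in Python: a generic netlink message is an nlmsghdr followed by a 4-byte genlmsghdr, then family attributes (GENL_ID_CTRL = 0x10 and CTRL_CMD_GETFAMILY = 3 come from linux/genetlink.h):

```python
import struct

GENL_HDRLEN = 4          # u8 cmd, u8 version, u16 reserved
GENL_ID_CTRL = 0x10      # the controller family always has ID 0x10
CTRL_CMD_GETFAMILY = 3
NLM_F_REQUEST = 0x01

def build_genl_msg(family_id, cmd, version, seq, pid, attrs=b""):
    genl = struct.pack("=BBH", cmd, version, 0)          # genlmsghdr
    length = 16 + GENL_HDRLEN + len(attrs)               # total nlmsg_len
    nl = struct.pack("=IHHII", length, family_id, NLM_F_REQUEST, seq, pid)
    return nl + genl + attrs

msg = build_genl_msg(GENL_ID_CTRL, CTRL_CMD_GETFAMILY, 2, 1, 0)
print(len(msg))  # 20 = 16-byte nlmsghdr + 4-byte genlmsghdr
```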
Family Resolution
Before using a generic netlink family, you must resolve its family ID:
#include <linux/genetlink.h>
#define GENL_CTRL_NAME "nlctrl" /* Controller family name */
#define GENL_CTRL_VERSION 2
/* Get family ID by name */
int get_family_id(int sock, const char *family_name) {
struct {
struct nlmsghdr nlh;
struct genlmsghdr gnlh;
char attrbuf[512];
} req;
struct rtattr *rta;
int family_id = -1;
/* Prepare request to controller */
memset(&req, 0, sizeof(req));
req.nlh.nlmsg_len = NLMSG_LENGTH(GENL_HDRLEN);
req.nlh.nlmsg_type = GENL_ID_CTRL; /* Controller family ID is always 0x10 */
req.nlh.nlmsg_flags = NLM_F_REQUEST;
req.nlh.nlmsg_seq = 1;
req.nlh.nlmsg_pid = getpid();
req.gnlh.cmd = CTRL_CMD_GETFAMILY;
req.gnlh.version = GENL_CTRL_VERSION;
/* Add family name attribute */
rta = (struct rtattr *)(((char *)&req) + NLMSG_ALIGN(req.nlh.nlmsg_len));
rta->rta_type = CTRL_ATTR_FAMILY_NAME;
rta->rta_len = RTA_LENGTH(strlen(family_name) + 1);
strcpy(RTA_DATA(rta), family_name);
req.nlh.nlmsg_len = NLMSG_ALIGN(req.nlh.nlmsg_len) + RTA_ALIGN(rta->rta_len);
/* Send request */
if (send(sock, &req, req.nlh.nlmsg_len, 0) < 0) {
return -1;
}
/* Receive response and parse family ID */
char buf[4096];
int len = recv(sock, buf, sizeof(buf), 0);
struct nlmsghdr *nlh = (struct nlmsghdr *)buf;
if (NLMSG_OK(nlh, len) && nlh->nlmsg_type != NLMSG_ERROR) {
struct genlmsghdr *gnlh = (struct genlmsghdr *)NLMSG_DATA(nlh);
rta = (struct rtattr *)((char *)gnlh + GENL_HDRLEN);
int rtalen = nlh->nlmsg_len - NLMSG_LENGTH(GENL_HDRLEN);
while (RTA_OK(rta, rtalen)) {
if (rta->rta_type == CTRL_ATTR_FAMILY_ID) {
family_id = *(__u16 *)RTA_DATA(rta);
break;
}
rta = RTA_NEXT(rta, rtalen);
}
}
return family_id;
}
Example: nl80211 (WiFi Configuration)
nl80211 is one of the most commonly used generic netlink families for WiFi configuration.
Listing WiFi Interfaces:
#include <linux/netlink.h>
#include <linux/genetlink.h>
#include <linux/nl80211.h>
#include <sys/socket.h>
#include <unistd.h>
#include <stdio.h>
#include <string.h>
int list_wifi_interfaces() {
int sock;
struct {
struct nlmsghdr nlh;
struct genlmsghdr gnlh;
} req;
int nl80211_id;
/* Create socket */
sock = socket(AF_NETLINK, SOCK_RAW, NETLINK_GENERIC);
if (sock < 0) {
perror("socket");
return -1;
}
/* Get nl80211 family ID */
nl80211_id = get_family_id(sock, "nl80211");
if (nl80211_id < 0) {
fprintf(stderr, "Failed to get nl80211 family ID\n");
close(sock);
return -1;
}
/* Prepare request */
memset(&req, 0, sizeof(req));
req.nlh.nlmsg_len = NLMSG_LENGTH(GENL_HDRLEN);
req.nlh.nlmsg_type = nl80211_id;
req.nlh.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
req.nlh.nlmsg_seq = 1;
req.nlh.nlmsg_pid = getpid();
req.gnlh.cmd = NL80211_CMD_GET_INTERFACE;
req.gnlh.version = 1;
/* Send request */
if (send(sock, &req, req.nlh.nlmsg_len, 0) < 0) {
perror("send");
close(sock);
return -1;
}
/* Receive and process response */
char buf[8192];
while (1) {
struct nlmsghdr *nlh;
int len = recv(sock, buf, sizeof(buf), 0);
if (len < 0) {
perror("recv");
break;
}
for (nlh = (struct nlmsghdr *)buf;
NLMSG_OK(nlh, len);
nlh = NLMSG_NEXT(nlh, len)) {
if (nlh->nlmsg_type == NLMSG_DONE) {
goto done;
}
if (nlh->nlmsg_type == NLMSG_ERROR) {
fprintf(stderr, "Error in response\n");
goto done;
}
struct genlmsghdr *gnlh = (struct genlmsghdr *)NLMSG_DATA(nlh);
struct rtattr *rta = (struct rtattr *)((char *)gnlh + GENL_HDRLEN);
int rtalen = nlh->nlmsg_len - NLMSG_LENGTH(GENL_HDRLEN);
printf("WiFi Interface:\n");
while (RTA_OK(rta, rtalen)) {
if (rta->rta_type == NL80211_ATTR_IFNAME) {
printf(" Name: %s\n", (char *)RTA_DATA(rta));
} else if (rta->rta_type == NL80211_ATTR_IFINDEX) {
printf(" Index: %u\n", *(__u32 *)RTA_DATA(rta));
} else if (rta->rta_type == NL80211_ATTR_WIPHY) {
printf(" PHY: %u\n", *(__u32 *)RTA_DATA(rta));
}
rta = RTA_NEXT(rta, rtalen);
}
printf("\n");
}
}
done:
close(sock);
return 0;
}
Python Examples with pyroute2
Working with netlink in C can be verbose. Python's pyroute2 library provides a much simpler interface.
Installation
pip install pyroute2
Example: Listing Network Interfaces
from pyroute2 import IPRoute
# Create IPRoute object
ip = IPRoute()
# Get all links
links = ip.get_links()
for link in links:
# Extract attributes
attrs = dict(link['attrs'])
print(f"Interface: {attrs.get('IFLA_IFNAME', 'unknown')}")
print(f" Index: {link['index']}")
print(f" State: {'UP' if link['flags'] & 1 else 'DOWN'}")
print(f" MTU: {attrs.get('IFLA_MTU', 'N/A')}")
if 'IFLA_ADDRESS' in attrs:
mac = ':'.join(f'{b:02x}' for b in attrs['IFLA_ADDRESS'])
print(f" MAC: {mac}")
print()
# Close connection
ip.close()
Example: Adding an IP Address
from pyroute2 import IPRoute
ip = IPRoute()
# Get interface index
idx = ip.link_lookup(ifname='eth0')[0]
# Add IP address
ip.addr('add', index=idx, address='192.168.1.100', prefixlen=24)
# Verify
addrs = ip.get_addr(index=idx)
for addr in addrs:
attrs = dict(addr['attrs'])
if 'IFA_ADDRESS' in attrs:
print(f"Address: {attrs['IFA_ADDRESS']}/{addr['prefixlen']}")
ip.close()
Example: Managing Routes
from pyroute2 import IPRoute
ip = IPRoute()
# Add a route
ip.route('add', dst='192.168.2.0/24', gateway='192.168.1.1')
# List routes
routes = ip.get_routes(family=2) # AF_INET
for route in routes:
attrs = dict(route['attrs'])
dst = attrs.get('RTA_DST', 'default')
gateway = attrs.get('RTA_GATEWAY', 'direct')
print(f"Route: {dst}/{route.get('dst_len', 0)} via {gateway}")
# Delete a route
ip.route('del', dst='192.168.2.0/24', gateway='192.168.1.1')
ip.close()
Example: Monitoring Network Events
from pyroute2 import IPRoute
ip = IPRoute()
# Bind to multicast groups
ip.bind()
print("Monitoring network events... (Ctrl+C to stop)")
try:
for message in ip.get():
event = message.get('event')
if event == 'RTM_NEWLINK':
attrs = dict(message['attrs'])
ifname = attrs.get('IFLA_IFNAME', 'unknown')
print(f"Link added/changed: {ifname}")
elif event == 'RTM_DELLINK':
attrs = dict(message['attrs'])
ifname = attrs.get('IFLA_IFNAME', 'unknown')
print(f"Link deleted: {ifname}")
elif event == 'RTM_NEWADDR':
attrs = dict(message['attrs'])
addr = attrs.get('IFA_ADDRESS', 'N/A')
print(f"Address added: {addr}")
elif event == 'RTM_DELADDR':
attrs = dict(message['attrs'])
addr = attrs.get('IFA_ADDRESS', 'N/A')
print(f"Address deleted: {addr}")
except KeyboardInterrupt:
print("\nStopped monitoring")
ip.close()
Example: Creating a VLAN Interface
from pyroute2 import IPRoute
ip = IPRoute()
try:
# Get parent interface index
parent_idx = ip.link_lookup(ifname='eth0')[0]
# Create VLAN interface
ip.link('add',
ifname='eth0.100',
kind='vlan',
link=parent_idx,
vlan_id=100)
# Get new interface index
vlan_idx = ip.link_lookup(ifname='eth0.100')[0]
# Bring interface up
ip.link('set', index=vlan_idx, state='up')
# Add IP address
ip.addr('add', index=vlan_idx, address='10.0.100.1', prefixlen=24)
print("VLAN interface eth0.100 created successfully")
except Exception as e:
print(f"Error: {e}")
ip.close()
Netlink Libraries
libnl (C Library)
libnl is the standard C library for netlink programming, providing high-level abstractions.
Installation:
# Ubuntu/Debian
sudo apt-get install libnl-3-dev libnl-route-3-dev libnl-genl-3-dev
# Fedora/RHEL
sudo dnf install libnl3-devel
Example:
#include <netlink/netlink.h>
#include <netlink/route/link.h>
int main() {
struct nl_sock *sock;
struct nl_cache *link_cache;
struct rtnl_link *link;
/* Allocate socket */
sock = nl_socket_alloc();
if (!sock) {
return -1;
}
/* Connect to route netlink */
nl_connect(sock, NETLINK_ROUTE);
/* Allocate link cache */
rtnl_link_alloc_cache(sock, AF_UNSPEC, &link_cache);
/* Iterate through links */
link = (struct rtnl_link *)nl_cache_get_first(link_cache);
while (link) {
printf("Interface: %s\n", rtnl_link_get_name(link));
printf(" Index: %d\n", rtnl_link_get_ifindex(link));
printf(" MTU: %u\n", rtnl_link_get_mtu(link));
link = (struct rtnl_link *)nl_cache_get_next((struct nl_object *)link);
}
/* Cleanup */
nl_cache_free(link_cache);
nl_socket_free(sock);
return 0;
}
Compilation:
gcc -o example example.c $(pkg-config --cflags --libs libnl-3.0 libnl-route-3.0)
pyroute2 (Python)
We've already seen several examples above. pyroute2 is the most popular Python library for netlink.
Features:
- IPRoute: Network interface and routing management
- IPDB: Transactional interface for network configuration
- Generic netlink support
- Network namespace support
- Async/await support
Other Libraries
Rust:
- netlink-rs: Rust bindings for netlink
- rtnetlink: High-level rtnetlink API
Go:
- vishvananda/netlink: Popular Go netlink library
- mdlayher/netlink: Low-level netlink library
Tools Using Netlink
iproute2
The ip command is the primary tool for network configuration on Linux, using rtnetlink.
Common Commands:
# Link management
ip link show
ip link set eth0 up
ip link set eth0 down
ip link set eth0 mtu 9000
# Address management
ip addr show
ip addr add 192.168.1.100/24 dev eth0
ip addr del 192.168.1.100/24 dev eth0
# Route management
ip route show
ip route add 192.168.2.0/24 via 192.168.1.1
ip route del 192.168.2.0/24
# Neighbor (ARP) management
ip neigh show
ip neigh add 192.168.1.1 lladdr 00:11:22:33:44:55 dev eth0
iw
WiFi configuration tool using nl80211:
# List WiFi devices
iw dev
# Scan for networks
iw dev wlan0 scan
# Connect to network
iw dev wlan0 connect "SSID"
# Get interface info
iw dev wlan0 info
# Set channel
iw dev wlan0 set channel 6
ss (Socket Statistics)
Uses NETLINK_SOCK_DIAG for socket information:
# Show all TCP sockets
ss -t
# Show listening sockets
ss -l
# Show detailed information
ss -e
# Show socket memory usage
ss -m
# Filter by state
ss state established
ethtool
Newer versions of ethtool use a dedicated netlink family for many operations, falling back to the legacy ioctl interface on older kernels:
# Show interface statistics
ethtool -S eth0
# Show driver info
ethtool -i eth0
# Set speed/duplex
ethtool -s eth0 speed 1000 duplex full
Advanced Topics
Netlink Error Handling
Netlink errors are returned via NLMSG_ERROR messages:
struct nlmsgerr {
int error; /* Negative errno or 0 for ack */
struct nlmsghdr msg; /* Original request header */
};
Handling Errors:
if (nlh->nlmsg_type == NLMSG_ERROR) {
struct nlmsgerr *err = (struct nlmsgerr *)NLMSG_DATA(nlh);
if (err->error == 0) {
/* Success acknowledgment */
printf("Success\n");
} else {
/* Error occurred */
fprintf(stderr, "Netlink error: %s\n", strerror(-err->error));
}
}
Extended Acknowledgments
Modern kernels support extended acknowledgments with error messages:
/* Request extended ack */
int val = 1;
setsockopt(sock, SOL_NETLINK, NETLINK_EXT_ACK, &val, sizeof(val));
When enabled, error messages can include:
- Human-readable error strings
- Attribute that caused the error
- Error offset in message
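To make the layout concrete, here is a sketch in stdlib Python of how extended-ack TLVs can be parsed. The constant values are copied from the uapi headers (`<linux/netlink.h>`); the message below is synthetic rather than read from a real socket, and it assumes the simple NLM_F_CAPPED case where the echoed request payload is truncated:

```python
import struct

# Constants from <linux/netlink.h> (stable ABI values)
NLMSG_ERROR = 0x2
NLM_F_ACK_TLVS = 0x200   # extended-ack TLVs are present
NLM_F_CAPPED = 0x100     # echoed request payload was truncated
NLMSGERR_ATTR_MSG = 1    # human-readable error string
NLMSGERR_ATTR_OFFS = 2   # byte offset of the offending attribute

def align4(n):
    return (n + 3) & ~3

def parse_ext_ack(buf):
    """Return (errno, extra) where extra maps TLV names to values.

    Assumes a capped error message (NLM_F_CAPPED), so the TLVs start
    right after the 4-byte error code and the 16-byte echoed header.
    """
    msg_len, msg_type, flags, seq, pid = struct.unpack_from('=IHHII', buf, 0)
    if msg_type != NLMSG_ERROR:
        return None, {}
    error = struct.unpack_from('=i', buf, 16)[0]
    extra = {}
    if flags & NLM_F_ACK_TLVS:
        off = 16 + 4 + 16            # nlmsghdr + error + echoed header
        while off + 4 <= msg_len:
            nla_len, nla_type = struct.unpack_from('=HH', buf, off)
            if nla_len < 4:
                break
            payload = buf[off + 4:off + nla_len]
            if nla_type == NLMSGERR_ATTR_MSG:
                extra['msg'] = payload.split(b'\0')[0].decode()
            elif nla_type == NLMSGERR_ATTR_OFFS:
                extra['offset'] = struct.unpack('=I', payload)[0]
            off += align4(nla_len)
    return error, extra

# Build a synthetic capped error message carrying an ext-ack string
text = b'invalid attribute\0'
tlv = struct.pack('=HH', 4 + len(text), NLMSGERR_ATTR_MSG) + text
tlv += b'\0' * (align4(len(tlv)) - len(tlv))
body = struct.pack('=i', -22) + struct.pack('=IHHII', 16, 0, 0, 1, 0) + tlv
hdr = struct.pack('=IHHII', 16 + len(body), NLMSG_ERROR,
                  NLM_F_ACK_TLVS | NLM_F_CAPPED, 1, 0)
err, extra = parse_ext_ack(hdr + body)
print(err, extra)   # error -22 is -EINVAL
```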
Multi-part Messages
Large responses are sent as multi-part messages:
/* Request with DUMP flag */
req.nlh.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
/* Receive loop */
while (1) {
len = recv(sock, buf, sizeof(buf), 0);
for (nlh = (struct nlmsghdr *)buf;
NLMSG_OK(nlh, len);
nlh = NLMSG_NEXT(nlh, len)) {
if (nlh->nlmsg_type == NLMSG_DONE) {
goto done; /* End of multi-part */
}
/* Process message */
process_message(nlh);
}
}
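The same DUMP-and-loop pattern can be sketched with nothing but Python's standard library (Linux only). The constants below are copied from the uapi headers; this is a minimal illustration of multi-part handling, not a replacement for pyroute2:

```python
import socket
import struct

# Constants from <linux/rtnetlink.h> and <linux/netlink.h>
RTM_GETLINK = 18
NLM_F_REQUEST = 0x1
NLM_F_DUMP = 0x300          # NLM_F_ROOT | NLM_F_MATCH
NLMSG_DONE = 0x3
NLMSG_ERROR = 0x2
IFLA_IFNAME = 3

def list_link_names():
    """Dump all interfaces over a raw NETLINK_ROUTE socket,
    looping until NLMSG_DONE terminates the multi-part reply."""
    sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW,
                         socket.NETLINK_ROUTE)
    try:
        # nlmsghdr (16 bytes) followed by ifinfomsg (16 bytes, AF_UNSPEC)
        req = struct.pack('=IHHII', 32, RTM_GETLINK,
                          NLM_F_REQUEST | NLM_F_DUMP, 1, 0)
        req += struct.pack('=BxHiII', 0, 0, 0, 0, 0)
        sock.send(req)
        names, done = [], False
        while not done:
            buf = sock.recv(65536)
            off = 0
            while off + 16 <= len(buf):
                msg_len, msg_type = struct.unpack_from('=IH', buf, off)
                if msg_len < 16 or msg_type in (NLMSG_DONE, NLMSG_ERROR):
                    done = True
                    break
                # Attributes start after nlmsghdr + ifinfomsg
                pos = off + 32
                while pos + 4 <= off + msg_len:
                    nla_len, nla_type = struct.unpack_from('=HH', buf, pos)
                    if nla_len < 4:
                        break
                    if nla_type == IFLA_IFNAME:
                        raw = buf[pos + 4:pos + nla_len]
                        names.append(raw.split(b'\0')[0].decode())
                    pos += (nla_len + 3) & ~3
                off += (msg_len + 3) & ~3
        return names
    finally:
        sock.close()

print(list_link_names())
```

Since RTM_GETLINK dumps are read-only, this runs without CAP_NET_ADMIN; the loopback interface `lo` should always appear in the result.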
Netlink Socket Options
/* Set receive buffer size */
int bufsize = 32768;
setsockopt(sock, SOL_SOCKET, SO_RCVBUF, &bufsize, sizeof(bufsize));
/* Report ENOBUFS errors when broadcasting */
int val = 1;
setsockopt(sock, SOL_NETLINK, NETLINK_BROADCAST_ERROR, &val, sizeof(val));
/* Enable listening to all namespaces */
setsockopt(sock, SOL_NETLINK, NETLINK_LISTEN_ALL_NSID, &val, sizeof(val));
/* Ignore ENOBUFS errors on receive-queue overrun */
val = 1;
setsockopt(sock, SOL_NETLINK, NETLINK_NO_ENOBUFS, &val, sizeof(val));
Network Namespaces
Netlink operates within network namespaces:
/* Open namespace file descriptor */
int nsfd = open("/var/run/netns/myns", O_RDONLY);
/* Switch to namespace */
setns(nsfd, CLONE_NEWNET);
/* Now netlink operations affect the new namespace */
int sock = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE);
/* ... */
Python Example:
from pyroute2 import NetNS
# Open namespace
ns = NetNS('myns')
# List interfaces in namespace
links = ns.get_links()
# Close namespace
ns.close()
Performance Considerations
Batching Requests:
/* Send multiple requests in one syscall */
struct iovec iov[10];
for (int i = 0; i < 10; i++) {
/* Prepare each message */
iov[i].iov_base = &requests[i];
iov[i].iov_len = requests[i].nlh.nlmsg_len;
}
struct msghdr msg = {
.msg_iov = iov,
.msg_iovlen = 10,
};
sendmsg(sock, &msg, 0);
Buffer Size:
- Use large buffers (32KB+) for DUMP operations
- Set SO_RCVBUF to avoid message drops
- Monitor ENOBUFS errors
Message Size:
- Keep messages under page size (4KB) when possible
- Use NLM_F_MULTI for large data transfers
Security Considerations
Capabilities Required:
- Most netlink operations require CAP_NET_ADMIN
- Read-only operations (GET) are typically allowed for all users
- Modify operations (NEW/DEL/SET) require privileges
Checking Permissions:
#include <sys/capability.h>
int check_net_admin() {
cap_t caps = cap_get_proc();
cap_flag_value_t value;
cap_get_flag(caps, CAP_NET_ADMIN, CAP_EFFECTIVE, &value);
cap_free(caps);
return value == CAP_SET;
}
Port ID Validation:
- Always validate sender's port ID
- Kernel messages always have nl_pid = 0
- A socket's nl_pid is a unique identifier assigned at bind time, not necessarily the process PID
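The port ID assignment is easy to observe from userspace. A minimal sketch with stdlib Python (Linux only): bind with nl_pid = 0, then read back the value the kernel actually assigned rather than assuming it equals the PID:

```python
import socket

# Binding with nl_pid = 0 asks the kernel to assign a unique port ID.
# For a process's first netlink socket this is usually the PID, but
# that is not guaranteed -- always read the real value back with
# getsockname() instead of assuming it.
sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW,
                     socket.NETLINK_ROUTE)
sock.bind((0, 0))               # (nl_pid, nl_groups); 0 = auto-assign
nl_pid, nl_groups = sock.getsockname()
print("kernel-assigned port ID:", nl_pid)
sock.close()
```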
Debugging Netlink
Using strace
# Trace netlink syscalls
strace -e sendto,recvfrom,bind,socket ip link show
# Show data in hex
strace -e trace=sendto,recvfrom -x ip addr show
# Follow forks
strace -f -e trace=network ip link show
Using nlmon
Create a netlink monitor interface:
# Load module
modprobe nlmon
# Create interface
ip link add nlmon0 type nlmon
ip link set nlmon0 up
# Capture with tcpdump
tcpdump -i nlmon0 -w netlink.pcap
# Or with Wireshark
wireshark -i nlmon0
Wireshark Dissectors
Wireshark can dissect netlink messages:
- rtnetlink messages
- Generic netlink messages
- nl80211 (WiFi) messages
Manual Parsing
# Hex-dump a tool's output (note: this shows the formatted text
# output, not the raw netlink frames; use nlmon to capture those)
ip -d link show | od -A x -t x1z -v
# hexdump gives friendlier formatting
ip link show 2>&1 | hexdump -C
Common Pitfalls
1. Incorrect Message Alignment
Wrong:
req.nlh.nlmsg_len = sizeof(struct nlmsghdr) + sizeof(struct ifinfomsg);
Correct:
req.nlh.nlmsg_len = NLMSG_LENGTH(sizeof(struct ifinfomsg));
2. Not Checking NLMSG_ERROR
Always check for error responses:
if (nlh->nlmsg_type == NLMSG_ERROR) {
struct nlmsgerr *err = NLMSG_DATA(nlh);
if (err->error != 0) {
/* Handle error */
}
}
3. Buffer Too Small
Use adequately sized buffers for DUMP operations:
char buf[32768]; /* 32KB is recommended */
4. Not Handling Multi-part Messages
Always loop until NLMSG_DONE:
while (1) {
for (nlh = ...; NLMSG_OK(nlh, len); nlh = NLMSG_NEXT(nlh, len)) {
if (nlh->nlmsg_type == NLMSG_DONE) goto done;
/* ... */
}
}
5. Incorrect Attribute Parsing
Always use macros for attribute manipulation:
/* Wrong */
rta = (struct rtattr *)((char *)ifi + sizeof(*ifi));
/* Correct */
rta = IFLA_RTA(ifi);
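Several of the pitfalls above come down to the 4-byte alignment arithmetic hidden inside these macros. A small Python model of that arithmetic, mirroring the uapi definitions (where the nlmsghdr is 16 bytes):

```python
# Python sketch of the netlink alignment macros from <linux/netlink.h>
NLMSG_ALIGNTO = 4
NLMSG_HDRLEN = 16   # sizeof(struct nlmsghdr), already 4-byte aligned

def nlmsg_align(length):
    # NLMSG_ALIGN: round up to the next multiple of 4
    return (length + NLMSG_ALIGNTO - 1) & ~(NLMSG_ALIGNTO - 1)

def nlmsg_length(payload):
    # NLMSG_LENGTH: payload size plus the aligned header
    return payload + nlmsg_align(NLMSG_HDRLEN)

def nlmsg_space(payload):
    # NLMSG_SPACE: total aligned size the message occupies
    return nlmsg_align(nlmsg_length(payload))

def rta_align(length):
    # RTA_ALIGN: attributes use the same 4-byte alignment
    return (length + 3) & ~3

print(nlmsg_length(16))  # 32: header + a 16-byte ifinfomsg
print(nlmsg_space(17))   # 36: 17-byte payload padded to 4 bytes
```

This is why `sizeof(struct nlmsghdr) + sizeof(struct ifinfomsg)` happens to work (both are already aligned) but breaks the moment an unaligned payload is involved.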
Summary
Netlink is a powerful and flexible IPC mechanism that has become the standard for kernel-userspace communication in Linux. Key takeaways:
Advantages:
- Bidirectional, asynchronous communication
- Multicast support for event notifications
- Extensible TLV format
- Type-safe and efficient binary protocol
Common Use Cases:
- Network configuration (rtnetlink)
- WiFi management (nl80211)
- Firewall rules (netfilter)
- Device events (kobject_uevent)
- Custom kernel modules (generic netlink)
Best Practices:
- Use libraries (libnl, pyroute2) for simpler code
- Always check for errors via NLMSG_ERROR
- Use proper alignment macros
- Handle multi-part messages correctly
- Set appropriate buffer sizes
Resources:
- Kernel Documentation: Documentation/userspace-api/netlink/
- libnl: https://www.infradead.org/~tgr/libnl/
- pyroute2: https://docs.pyroute2.org/
- iproute2 source code: https://git.kernel.org/pub/scm/network/iproute2/iproute2.git
Netlink continues to evolve, with new families and features being added regularly. Understanding netlink is essential for anyone working with Linux networking, device management, or kernel-userspace communication.
Essential Linux Commands Reference
A comprehensive guide to essential Linux commands with examples, use cases, and practical tips.
Table of Contents
- File System Navigation
- File Operations
- Text Processing
- Search and Find
- Process Management
- System Monitoring
- User Management
- Permissions
- Package Management
- Network Commands
- Service Management
- Compression
- Disk Management
- System Information
File System Navigation
ls - List Directory Contents
# Basic listing
ls # List files in current directory
ls -l # Long format with details
ls -a # Show hidden files
ls -lh # Human-readable sizes
ls -lah # Combine all above options
ls -R # Recursive listing
ls -lt # Sort by modification time
ls -lS # Sort by size
# Advanced usage
ls -i # Show inode numbers
ls -d */ # List only directories
ls --color=auto # Colored output
ls -ltr # Reverse time sort (oldest first)
# Examples
ls *.txt # List all .txt files
ls -l /var/log/ # List files in specific directory
ls -lh --sort=size # Sort by size, human-readable
Use Cases:
- Quick directory overview
- Check file permissions and ownership
- Find recently modified files
- Disk usage analysis
cd - Change Directory
cd /path/to/directory # Absolute path
cd relative/path # Relative path
cd .. # Parent directory
cd ../.. # Two levels up
cd - # Previous directory
cd ~ # Home directory
cd # Home directory (shorthand)
cd ~username # Another user's home
# Examples
cd /var/log # Go to log directory
cd ~/Documents # Go to Documents in home
cd - # Toggle between two directories
pwd - Print Working Directory
pwd # Show current directory
pwd -P # Show physical directory (resolve symlinks)
File Operations
cp - Copy Files
# Basic copying
cp source.txt dest.txt # Copy file
cp file1 file2 dir/ # Copy multiple files to directory
cp -r dir1/ dir2/ # Copy directory recursively
cp -i file dest # Interactive (prompt before overwrite)
cp -v file dest # Verbose output
cp -u file dest # Update (copy only if newer)
# Advanced options
cp -p file dest # Preserve attributes (mode, ownership, timestamps)
cp -a dir1/ dir2/ # Archive mode (recursive + preserve)
cp --backup file dest # Create backup before overwriting
# Examples
cp /etc/config ~/.config/ # Copy config file to home
cp -r /var/www/* /backup/ # Backup web directory
cp -av src/ dest/ # Full directory copy with attributes
mv - Move/Rename Files
# Basic move/rename
mv old.txt new.txt # Rename file
mv file.txt dir/ # Move file to directory
mv file1 file2 dir/ # Move multiple files
mv -i file dest # Interactive mode
mv -v file dest # Verbose output
# Examples
mv *.log /var/log/ # Move all log files
mv -n file dest # No overwrite
mv --backup=numbered f d # Numbered backups
rm - Remove Files
# Basic removal
rm file.txt # Remove file
rm -r directory/ # Remove directory recursively
rm -f file # Force removal (no confirmation)
rm -rf directory/ # Force remove directory
rm -i file # Interactive (ask before removal)
rm -v file # Verbose output
# Safe practices
rm -I files* # Prompt once before removing many files
rm -d emptydir/ # Remove empty directory only
# Examples
rm *.tmp # Remove all .tmp files
rm -rf /tmp/session* # Force remove temp sessions
find . -name "*.bak" -delete # Alternative: safer removal
Warning: Use rm -rf with extreme caution!
mkdir - Make Directories
mkdir newdir # Create directory
mkdir -p path/to/dir # Create parent directories
mkdir -m 755 dir # Set permissions
mkdir -v dir # Verbose output
# Examples
mkdir -p project/{src,bin,doc} # Create multiple directories
mkdir -p ~/backup/$(date +%Y-%m-%d) # Date-based backup dir
touch - Create/Update Files
touch file.txt # Create empty file or update timestamp
touch -c file # No create (only update if exists)
touch -t 202301011200 file # Set specific timestamp
touch -d "2023-01-01" file # Set date
# Examples
touch {1..10}.txt # Create multiple files
touch -r ref.txt new.txt # Copy timestamp from reference
Text Processing
cat - Concatenate and Display
cat file.txt # Display file contents
cat file1 file2 # Concatenate multiple files
cat > file.txt # Create file from stdin (Ctrl+D to end)
cat >> file.txt # Append to file
cat -n file.txt # Number all lines
cat -b file.txt # Number non-blank lines
cat -s file.txt # Squeeze multiple blank lines
# Examples
cat /etc/passwd # View user accounts
cat file1 file2 > combined # Combine files
cat /dev/null > file.txt # Empty a file
grep - Search Text Patterns
# Basic search
grep "pattern" file.txt # Search for pattern
grep -i "pattern" file # Case-insensitive
grep -v "pattern" file # Invert match (exclude)
grep -r "pattern" dir/ # Recursive search
grep -n "pattern" file # Show line numbers
grep -c "pattern" file # Count matches
# Advanced options
grep -w "word" file # Match whole words only
grep -A 3 "pattern" file # Show 3 lines after match
grep -B 3 "pattern" file # Show 3 lines before match
grep -C 3 "pattern" file # Show 3 lines context
grep -l "pattern" files* # List filenames only
grep -E "regex" file # Extended regex (or egrep)
# Regular expressions
grep "^start" file # Lines starting with "start"
grep "end$" file # Lines ending with "end"
grep "^$" file # Empty lines
grep "[0-9]\{3\}" file # Three consecutive digits
# Examples
grep -r "TODO" ~/code/ # Find all TODOs in code
grep -i "error" /var/log/*.log # Find errors in logs
ps aux | grep nginx # Find nginx processes
grep -v "^#" config.txt # Show non-comment lines
netstat -tulpn | grep :80 # Find what's using port 80
Use Cases:
- Log file analysis
- Finding specific code patterns
- Filtering command output
- Configuration file parsing
sed - Stream Editor
# Basic substitution
sed 's/old/new/' file # Replace first occurrence per line
sed 's/old/new/g' file # Replace all occurrences
sed 's/old/new/gi' file # Case-insensitive global replace
sed -i 's/old/new/g' file # In-place editing
sed -i.bak 's/old/new/g' file # In-place with backup
# Line operations
sed -n '5p' file # Print line 5
sed -n '1,5p' file # Print lines 1-5
sed '5d' file # Delete line 5
sed '/pattern/d' file # Delete lines matching pattern
sed '1,3d' file # Delete lines 1-3
# Advanced usage
sed '/pattern/s/old/new/' file # Replace only in matching lines
sed 's/^/ /' file # Add 2 spaces at start of each line
sed 's/$/\r/' file # Convert to DOS line endings
sed '/^$/d' file # Remove empty lines
# Examples
sed 's/localhost/127.0.0.1/g' config # Replace hostname
sed -n '/ERROR/,/END/p' log # Print between patterns
sed '/#/d' file # Remove comment lines
sed 's/\t/ /g' file # Replace tabs with spaces
awk - Text Processing Language
# Basic usage
awk '{print}' file # Print all lines
awk '{print $1}' file # Print first column
awk '{print $1,$3}' file # Print columns 1 and 3
awk '{print $NF}' file # Print last column
awk '{print NR,$0}' file # Print line numbers
# Field separator
awk -F: '{print $1}' /etc/passwd # Custom delimiter
awk -F',' '{print $2}' data.csv # CSV parsing
# Patterns and conditions
awk '/pattern/' file # Print lines matching pattern
awk '$3 > 100' file # Print if column 3 > 100
awk 'NR==5' file # Print line 5
awk 'NR>=5 && NR<=10' file # Print lines 5-10
awk 'length($0) > 80' file # Print lines longer than 80 chars
# Calculations
awk '{sum+=$1} END {print sum}' file # Sum first column
awk '{print $1*$2}' file # Multiply columns 1 and 2
# Examples
awk -F: '{print $1}' /etc/passwd # List usernames
ps aux | awk '{print $2,$11}' # Print PID and command
df -h | awk '$5+0 > 80 {print $0}' # Disk usage > 80%
netstat -an | awk '/ESTABLISHED/ {print $5}' # Connected IPs
awk '{sum+=$1} END {print sum/NR}' data # Average of column 1
Use Cases:
- Log parsing and analysis
- Data extraction from structured text
- Quick calculations on columns
- Report generation
head - Display Beginning of File
head file.txt # First 10 lines
head -n 20 file.txt # First 20 lines
head -c 100 file.txt # First 100 bytes
head -n -5 file.txt # All but last 5 lines
# Examples
head -n 1 *.txt # First line of each file
head /var/log/syslog # Quick log preview
tail - Display End of File
tail file.txt # Last 10 lines
tail -n 20 file.txt # Last 20 lines
tail -f file.txt # Follow file (live updates)
tail -F file.txt # Follow with retry (if rotated)
tail -n +5 file.txt # From line 5 to end
# Examples
tail -f /var/log/syslog # Monitor system log
tail -n 100 -f app.log # Follow last 100 lines
tail -f log | grep ERROR # Filter live log stream
sort - Sort Lines
sort file.txt # Alphabetical sort
sort -r file.txt # Reverse sort
sort -n file.txt # Numeric sort
sort -u file.txt # Unique lines only
sort -k 2 file.txt # Sort by column 2
sort -t: -k3 -n /etc/passwd # Numeric sort by field 3
# Examples
sort -t',' -k2 -n data.csv # Sort CSV by second column
ls -l | sort -k 5 -n # Sort files by size
history | awk '{print $2}' | sort | uniq -c | sort -rn # Most used commands
uniq - Report Unique Lines
uniq file.txt # Remove adjacent duplicates
uniq -c file.txt # Count occurrences
uniq -d file.txt # Show only duplicates
uniq -u file.txt # Show only unique lines
uniq -i file.txt # Case-insensitive
# Examples (usually with sort)
sort file.txt | uniq # Remove all duplicates
sort file.txt | uniq -c | sort -rn # Frequency count
Search and Find
find - Search for Files
# By name
find . -name "file.txt" # Find by exact name
find . -iname "*.txt" # Case-insensitive name
find /var -name "*.log" # Find in specific directory
# By type
find . -type f # Find files
find . -type d # Find directories
find . -type l # Find symbolic links
# By size
find . -size +100M # Files larger than 100MB
find . -size -1k # Files smaller than 1KB
find . -empty # Empty files/directories
# By time
find . -mtime -7 # Modified in last 7 days
find . -atime +30 # Accessed more than 30 days ago
find . -ctime -1 # Changed in last 24 hours
find . -mmin -60 # Modified in last 60 minutes
# By permissions
find . -perm 777 # Exactly 777 permissions
find . -perm -644 # At least 644 permissions
find . -user root # Owned by root
find . -group www-data # Owned by www-data group
# Actions
find . -name "*.tmp" -delete # Delete found files
find . -name "*.sh" -exec chmod +x {} \; # Execute command
find . -type f -exec wc -l {} + # Count lines
# Examples
find /home -user john -name "*.pdf" # User's PDF files
find . -name "*.log" -mtime +30 -delete # Delete old logs
find /var/www -type f -perm 777 # Find world-writable files
find . -size +50M -size -100M # Files between 50-100MB
find . -name "*.js" -exec grep -l "TODO" {} \; # Find TODOs
locate - Quick File Search
locate filename # Quick search in database
locate -i filename # Case-insensitive
locate -c pattern # Count matches
locate -b '\filename' # Exact basename match
# Update database
sudo updatedb # Refresh locate database
# Examples
locate nginx.conf # Find nginx config
locate -r '\.conf$' # All .conf files
which - Locate Command
which python # Find command path
which -a python # Show all matches
# Examples
which docker # Find Docker binary
type python # Alternative (bash builtin)
whereis - Locate Binary/Source/Manual
whereis ls # Find binary, source, man page
whereis -b ls # Binary only
whereis -m ls # Manual only
whereis -s ls # Source only
Process Management
ps - Process Status
# Basic usage
ps # Current shell processes
ps aux # All processes (BSD style)
ps -ef # All processes (System V style)
ps -u username # User's processes
ps -p 1234 # Specific process by PID
# Detailed view
ps aux | grep nginx # Find specific process
ps auxf # Process tree (forest)
ps -eo pid,user,%cpu,%mem,cmd # Custom columns
ps --sort=-%mem # Sort by memory usage
ps -C nginx # Processes by command name
# Examples
ps aux | head # Top processes
ps -eo pid,ppid,cmd,%mem,%cpu --sort=-%cpu | head # CPU hogs
ps -U www-data # Web server processes
top - Interactive Process Viewer
top # Launch interactive viewer
top -u username # Show user's processes
top -p 1234 # Monitor specific PID
top -b -n 1 # Batch mode (one iteration)
top -d 5 # Update every 5 seconds
# Interactive commands (while running)
# k - kill process
# r - renice (change priority)
# M - sort by memory
# P - sort by CPU
# q - quit
# h - help
# Examples
top -o %MEM # Sort by memory (macOS)
top | head -20 # First 20 lines
htop - Enhanced Process Viewer
htop # Launch htop (if installed)
htop -u username # Show user's processes
htop -p PID,PID # Monitor specific PIDs
# Interactive features
# F9 - kill process
# F7/F8 - adjust priority
# F5 - tree view
# F3 - search
# F4 - filter
kill - Terminate Process
# By PID
kill 1234 # Graceful termination (SIGTERM)
kill -9 1234 # Force kill (SIGKILL)
kill -15 1234 # Explicit SIGTERM
kill -HUP 1234 # Hangup signal (reload config)
# Signal list
kill -l # List all signals
# Examples
kill $(pidof firefox) # Kill by process name
killall nginx # Kill all nginx processes
pkill -u username # Kill user's processes
pkill/killall - Kill by Name
pkill firefox # Kill by process name
pkill -u username # Kill user's processes
pkill -9 python # Force kill all python processes
pkill -f "script.py" # Kill by full command line
killall nginx # Kill all nginx processes
killall -u username bash # Kill user's bash sessions
jobs/bg/fg - Job Control
# Job control
command & # Run in background
jobs # List background jobs
fg %1 # Bring job 1 to foreground
bg %1 # Resume job 1 in background
Ctrl+Z # Suspend current job
disown %1 # Detach job from shell
# Examples
find / -name "*.log" > /tmp/logs.txt & # Background search
sleep 100 & # Background sleep
jobs -l # List with PIDs
nohup - Run Immune to Hangups
nohup command & # Run detached from terminal
nohup ./script.sh & # Script continues after logout
nohup command > output.log 2>&1 & # Redirect output
# Examples
nohup python server.py > server.log 2>&1 &
nohup long_running_task.sh &
systemctl - Service Management
# Service operations
systemctl start nginx # Start service
systemctl stop nginx # Stop service
systemctl restart nginx # Restart service
systemctl reload nginx # Reload configuration
systemctl status nginx # Service status
systemctl enable nginx # Enable at boot
systemctl disable nginx # Disable at boot
# System operations
systemctl reboot # Reboot system
systemctl poweroff # Shutdown system
systemctl suspend # Suspend system
# Information
systemctl list-units # List active units
systemctl list-unit-files # List all unit files
systemctl --failed # Show failed services
systemctl is-enabled nginx # Check if enabled
systemctl is-active nginx # Check if running
# Examples
systemctl status sshd # Check SSH status
systemctl restart apache2 # Restart web server
systemctl list-dependencies nginx # Show dependencies
System Monitoring
df - Disk Free Space
df # Show disk usage
df -h # Human-readable sizes
df -i # Inode usage
df -T # Show filesystem type
df /home # Specific mount point
# Examples
df -h | grep -v tmpfs # Exclude temporary filesystems
df -h --total # Show total summary
du - Disk Usage
du # Directory space usage
du -h # Human-readable
du -sh * # Summary for each item
du -sh directory # Total for directory
du -ah # All files (not just directories)
du --max-depth=1 # Limit directory depth
# Examples
du -sh /var/log # Log directory size
du -h | sort -rh | head -10 # Top 10 largest directories
du -ch *.log | tail -1 # Total size of log files
free - Memory Usage
free # Show memory usage
free -h # Human-readable
free -m # In megabytes
free -g # In gigabytes
free -s 5 # Update every 5 seconds
# Examples
free -h # Quick memory check
watch -n 1 free -h # Monitor continuously
vmstat - Virtual Memory Statistics
vmstat # Memory, process, paging stats
vmstat 1 # Update every second
vmstat 1 10 # 10 samples, 1 second apart
vmstat -s # Memory statistics
vmstat -d # Disk statistics
# Examples
vmstat 5 # Monitor system stats
iostat - I/O Statistics
iostat # CPU and disk I/O stats
iostat -x # Extended statistics
iostat -d 1 # Disk stats every second
iostat -p sda # Specific disk
# Examples
iostat -xz 1 # Extended, skip zero-activity
netstat - Network Statistics
netstat -tulpn # Listening ports with programs
netstat -an # All connections, numeric
netstat -r # Routing table
netstat -i # Network interfaces
netstat -s # Protocol statistics
# Examples
netstat -tulpn | grep :80 # Check port 80
netstat -ant | grep ESTABLISHED # Active connections
ss - Socket Statistics (newer netstat)
ss -tulpn # Listening TCP/UDP ports
ss -ta # All TCP sockets
ss -ua # All UDP sockets
ss -s # Summary statistics
ss dst :80 # Connections to port 80
# Examples
ss -t state established # Established TCP connections
ss -o state established # With timer info
ss -p | grep ssh # SSH connections
lsof - List Open Files
lsof # All open files
lsof -u username # User's open files
lsof -i :80 # Processes using port 80
lsof -i TCP:1-1024 # Processes on ports 1-1024
lsof /path/to/file # What's accessing a file
lsof -c nginx # Files opened by nginx
lsof -p 1234 # Files opened by PID
# Examples
lsof -i -P -n # Network connections (no DNS)
lsof +D /var/log # Everything under directory
lsof -t -i :8080 # PIDs using port 8080
User Management
useradd - Create User
useradd username # Create user
useradd -m username # Create with home directory
useradd -m -s /bin/bash username # Specify shell
useradd -m -G group1,group2 user # Add to groups
useradd -m -e 2024-12-31 user # With expiry date
# Examples
useradd -m -s /bin/bash john
useradd -m -G sudo,docker admin
usermod - Modify User
usermod -aG sudo username # Add to sudo group
usermod -s /bin/zsh user # Change shell
usermod -L username # Lock account
usermod -U username # Unlock account
usermod -e 2024-12-31 user # Set expiry date
# Examples
usermod -aG docker username # Add to docker group
usermod -d /new/home -m user # Change home directory
userdel - Delete User
userdel username # Delete user
userdel -r username # Delete user and home directory
passwd - Change Password
passwd # Change your password
passwd username # Change user's password (as root)
passwd -l username # Lock password
passwd -u username # Unlock password
passwd -e username # Expire password (force change)
# Examples
passwd john # Set password for john
passwd -S john # Show password status
su - Switch User
su # Switch to root
su username # Switch to user
su - username # Switch with environment
su -c "command" username # Run command as user
# Examples
su - postgres # Switch to postgres user
su -c "systemctl restart nginx" root
sudo - Execute as Superuser
sudo command # Run command as root
sudo -u user command # Run as specific user
sudo -i # Interactive root shell
sudo -s # Shell as root
sudo -l # List allowed commands
sudo -k # Invalidate cached credentials
# Examples
sudo apt update # Update package lists
sudo -u www-data touch /var/www/file
sudo !! # Run last command with sudo
Permissions
chmod - Change File Mode
# Numeric mode
chmod 644 file # rw-r--r--
chmod 755 file # rwxr-xr-x
chmod 777 file # rwxrwxrwx
chmod 600 file # rw-------
# Symbolic mode
chmod u+x file # Add execute for user
chmod g-w file # Remove write for group
chmod o=r file # Set others to read only
chmod a+r file # Add read for all
chmod u+x,g+x file # Multiple changes
# Recursive
chmod -R 755 directory # Apply recursively
# Examples
chmod +x script.sh # Make executable
chmod -R 755 /var/www # Web directory permissions
chmod u+s file # Set SUID bit
chmod g+s directory # Set SGID bit
chmod +t directory # Set sticky bit
Permission numbers:
- 4 = read (r)
- 2 = write (w)
- 1 = execute (x)
- Sum for each user/group/others
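As a worked example of the arithmetic: 7 = 4+2+1 (rwx), 5 = 4+1 (r-x), 4 = read only, so mode 754 means rwxr-xr--. The sums can be verified with a quick stdlib Python check (the temporary file is just for illustration):

```python
import os
import stat
import tempfile

# Create a throwaway file, set mode 754, and read the mode back
fd, path = tempfile.mkstemp()
os.close(fd)
os.chmod(path, 0o754)            # 7=rwx (user), 5=r-x (group), 4=r-- (others)
st = os.stat(path)
mode = stat.S_IMODE(st.st_mode)
symbolic = stat.filemode(st.st_mode)
print(oct(mode), symbolic)       # 0o754 -rwxr-xr--
os.remove(path)
```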
chown - Change Ownership
chown user file # Change owner
chown user:group file # Change owner and group
chown -R user:group dir # Recursive change
chown --reference=ref file # Copy ownership from reference
# Examples
chown www-data:www-data /var/www/html
chown -R mysql:mysql /var/lib/mysql
chown john:developers project/
chgrp - Change Group
chgrp group file # Change group
chgrp -R group directory # Recursive change
# Examples
chgrp www-data website/
chgrp -R developers /opt/project
umask - Default Permissions
umask # Show current umask
umask 022 # Set umask (755 for dirs, 644 for files)
umask 002 # Set umask (775 for dirs, 664 for files)
# Examples
umask 077 # Private by default (700/600)
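A quick sketch of how the umask interacts with the creation defaults (new files start from 666, new directories from 777; the umask bits are cleared from that base):

```shell
(
  umask 027                 # clear group-write and all "other" bits
  touch f.txt               # 666 & ~027 = 640 (rw-r-----)
  mkdir d                   # 777 & ~027 = 750 (rwxr-x---)
  stat -c '%a %n' f.txt d   # prints: 640 f.txt / 750 d
)
```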
Package Management
APT (Debian/Ubuntu)
# Update
apt update # Update package lists
apt upgrade # Upgrade packages
apt full-upgrade # Upgrade, installing/removing packages if needed
apt dist-upgrade # apt-get era alias for full-upgrade
# Install/Remove
apt install package # Install package
apt install package1 package2 # Multiple packages
apt remove package # Remove package
apt purge package # Remove package and config
apt autoremove # Remove unused dependencies
# Search and Info
apt search keyword # Search packages
apt show package # Package information
apt list --installed # List installed packages
apt list --upgradable # List upgradable packages
# Examples
apt install nginx # Install web server
apt remove --purge apache2 # Complete removal
apt install build-essential git curl
DNF/YUM (RHEL/Fedora/CentOS)
# Update
dnf upgrade # Upgrade packages
dnf update # Deprecated alias for upgrade
# Install/Remove
dnf install package # Install package
dnf remove package # Remove package
dnf autoremove # Remove orphaned dependencies
# Search and Info
dnf search keyword # Search packages
dnf info package # Package information
dnf list installed # List installed packages
# Examples
dnf install httpd # Install Apache
dnf groupinstall "Development Tools"
Snap (Universal)
snap install package # Install snap package
snap remove package # Remove package
snap refresh # Update all snaps
snap list # List installed snaps
snap find keyword # Search snaps
# Examples
snap install docker
snap install --classic code # Classic confinement
Network Commands
ip - Network Configuration
# Address management
ip addr show # Show all IP addresses
ip addr show eth0 # Show specific interface
ip addr add IP/MASK dev eth0 # Add IP address
ip addr del IP/MASK dev eth0 # Delete IP address
# Link management
ip link show # Show network interfaces
ip link set eth0 up # Bring interface up
ip link set eth0 down # Bring interface down
# Route management
ip route show # Show routing table
ip route add default via GATEWAY # Add default route
ip route del default # Delete default route
# Neighbor (ARP)
ip neigh show # Show ARP table
# Examples
ip addr show # Quick network overview
ip route get 8.8.8.8 # Show route to destination
ip link set eth0 mtu 9000 # Set MTU
ping - Test Connectivity
ping host # Ping host
ping -c 4 host # Send 4 packets
ping -i 2 host # 2 second interval
ping -s 1000 host # 1000 byte packets
ping -W 1 host # 1 second timeout
# Examples
ping -c 4 google.com # Test internet connectivity
ping 192.168.1.1 # Test local gateway
curl - Transfer Data
# Basic requests
curl URL # GET request
curl -O URL # Download file (keep name)
curl -o file.txt URL # Download with custom name
curl -I URL # Headers only
curl -L URL # Follow redirects
# HTTP methods
curl -X POST URL # POST request
curl -X PUT URL # PUT request
curl -X DELETE URL # DELETE request
# Data and headers
curl -d "param=value" URL # POST data
curl -H "Header: Value" URL # Custom header
curl -u user:pass URL # Basic authentication
curl -b cookies.txt URL # Send cookies
curl -c cookies.txt URL # Save cookies
# Examples
curl -I https://google.com # Check headers
curl -o page.html https://example.com
curl -X POST -H "Content-Type: application/json" -d '{"key":"value"}' API_URL
curl -u admin:password http://api.example.com
wget - Download Files
wget URL # Download file
wget -O filename URL # Save with custom name
wget -c URL # Continue interrupted download
wget -b URL # Background download
wget -r URL # Recursive download
wget --limit-rate=200k URL # Limit download speed
wget -i urls.txt # Download multiple URLs
# Examples
wget https://example.com/file.iso
wget -c https://mirrors.kernel.org/ubuntu/ubuntu-22.04.iso
wget -r -np -k http://example.com # Mirror website
ssh - Secure Shell
ssh user@host # Connect to host
ssh -p 2222 user@host # Custom port
ssh -i key.pem user@host # Use specific key
ssh user@host command # Run remote command
ssh -L 8080:localhost:80 user@host # Local port forwarding
ssh -R 8080:localhost:80 user@host # Remote port forwarding
# Examples
ssh john@192.168.1.100
ssh -i ~/.ssh/aws-key.pem ubuntu@ec2-instance
ssh user@host 'df -h' # Check remote disk space
scp - Secure Copy
scp file user@host:/path # Copy to remote
scp user@host:/path/file . # Copy from remote
scp -r directory user@host:/path # Copy directory
scp -P 2222 file user@host:/path # Custom port
scp -i key.pem file user@host:/path # Specific key
# Examples
scp backup.tar.gz user@backup-server:/backups/
scp -r website/ user@server:/var/www/
scp user@server:/var/log/app.log ./logs/
rsync - Sync Files
rsync -av source/ dest/ # Archive and verbose
rsync -avz source/ user@host:dest/ # With compression
rsync -av --delete src/ dst/ # Delete in destination
rsync -av --progress src/ dst/ # Show progress
rsync -av --exclude="*.log" src/ dst/ # Exclude pattern
# Examples
rsync -avz ~/project/ backup-server:/backups/project/
rsync -av --delete /var/www/ /backup/www/
rsync -avz -e "ssh -p 2222" src/ user@host:dest/
Service Management
journalctl - Query Systemd Journal
journalctl # Show all logs
journalctl -f # Follow logs (tail -f)
journalctl -u nginx # Service logs
journalctl -u nginx -f # Follow service logs
journalctl --since today # Today's logs
journalctl --since "1 hour ago" # Last hour
journalctl -p err # Error priority and above
journalctl -k # Kernel messages
journalctl -b # Current boot logs
journalctl --disk-usage # Disk usage by logs
# Examples
journalctl -u sshd -n 100 # Last 100 SSH log entries
journalctl --since "2024-01-01" --until "2024-01-31"
journalctl -u nginx --since yesterday
Compression
tar - Archive Files
# Create archives
tar -cvf archive.tar files # Create tar archive
tar -czvf archive.tar.gz files # Create gzipped archive
tar -cjvf archive.tar.bz2 files # Create bzip2 archive
tar -cJvf archive.tar.xz files # Create xz archive
# Extract archives
tar -xvf archive.tar # Extract tar
tar -xzvf archive.tar.gz # Extract gzipped
tar -xjvf archive.tar.bz2 # Extract bzip2
tar -xJvf archive.tar.xz # Extract xz
tar -xzvf archive.tar.gz -C /dest # Extract to directory
# List contents
tar -tvf archive.tar # List contents
tar -tzvf archive.tar.gz # List gzipped archive
# Examples
tar -czvf backup-$(date +%Y%m%d).tar.gz /home/user/
tar -xzvf website.tar.gz -C /var/www/
tar -czvf project.tar.gz --exclude='*.log' project/
gzip/gunzip - Compress Files
gzip file.txt # Compress (creates file.txt.gz)
gzip -k file.txt # Keep original
gzip -9 file.txt # Maximum compression
gunzip file.txt.gz # Decompress
gzip -l file.txt.gz # List compression info
# Examples
gzip -r directory/ # Compress all files in directory
gzip -c file.txt > file.txt.gz # Keep original
zip/unzip - Zip Archives
zip archive.zip files # Create zip
zip -r archive.zip dir/ # Recursive zip
unzip archive.zip # Extract zip
unzip -l archive.zip # List contents
unzip archive.zip -d /dest # Extract to directory
# Examples
zip -r backup.zip /home/user/Documents
unzip file.zip
zip -e secure.zip file # Password protected
Disk Management
fdisk - Partition Disk
fdisk -l # List all disks and partitions
fdisk /dev/sda # Open disk for partitioning
# Interactive commands (in fdisk):
# n - new partition
# d - delete partition
# p - print partition table
# w - write changes
# q - quit without saving
mount/umount - Mount Filesystems
mount # Show mounted filesystems
mount /dev/sda1 /mnt # Mount partition
mount -t nfs server:/share /mnt # Mount NFS
mount -o loop disk.iso /mnt # Mount ISO
umount /mnt # Unmount
umount -l /mnt # Lazy unmount
# Examples
mount /dev/sdb1 /media/usb
mount -t cifs //server/share /mnt -o username=user
mount --bind /source /dest # Bind mount
mkfs - Make Filesystem
mkfs.ext4 /dev/sda1 # Create ext4 filesystem
mkfs.xfs /dev/sda1 # Create XFS filesystem
mkfs.vfat /dev/sda1 # Create FAT filesystem
# Examples
mkfs.ext4 -L MyDisk /dev/sdb1 # With label
mkfs.ext4 -m 1 /dev/sdb1 # Reserve 1% for root
System Information
uname - System Information
uname -a # All information
uname -r # Kernel release
uname -m # Machine hardware
uname -o # Operating system
lscpu - CPU Information
lscpu # Detailed CPU info
lscpu | grep "CPU(s)" # Number of CPUs
lsblk - Block Devices
lsblk # List block devices
lsblk -f # Show filesystems
lsblk -o NAME,SIZE,TYPE,MOUNTPOINT # Custom columns
lspci - PCI Devices
lspci # List PCI devices
lspci -v # Verbose output
lspci | grep VGA # Graphics card info
lsusb - USB Devices
lsusb # List USB devices
lsusb -v # Verbose output
hostname - System Hostname
hostname # Show hostname
hostname -I # Show IP addresses
hostnamectl # Detailed host information
hostnamectl set-hostname new-name # Change hostname
date - Date and Time
date # Current date and time
date +%Y-%m-%d # Custom format (2024-01-15)
date +%s # Unix timestamp
date -d "yesterday" # Yesterday's date
date -d "next Monday" # Next Monday
# Examples
date +%Y%m%d-%H%M%S # 20240115-143025
date -d @1704067200 # Convert timestamp
uptime - System Uptime
uptime # Show how long the system has been running
uptime -p # Pretty format
uptime -s # Since when
Practical Tips and Best Practices
Command Chaining
# Sequential execution
command1 ; command2 # Run both regardless
command1 && command2 # Run command2 if command1 succeeds
command1 || command2 # Run command2 if command1 fails
# Examples
apt update && apt upgrade # Update then upgrade
make || echo "Build failed"
cd /tmp && rm -rf old_files
Redirection and Pipes
# Output redirection
command > file # Overwrite file
command >> file # Append to file
command 2> file # Redirect stderr
command > file 2>&1 # Redirect both stdout and stderr
command &> file # Redirect both (shorthand)
# Input redirection
command < file # Read from file
command << EOF # Here document
multiline input
EOF
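A concrete here-document, feeding multiple lines of stdin into a command (filename is arbitrary):

```shell
# Everything up to the EOF marker becomes the command's standard input
cat << EOF > greeting.txt
Hello
World
EOF
wc -l < greeting.txt        # prints: 2
```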
Netfilter
Netfilter is a framework provided by the Linux kernel for packet filtering, network address translation (NAT), and other packet mangling. It lets system administrators define rules for how the kernel should handle packets.
Key Concepts
- Hooks: Netfilter provides hooks in the networking stack where packets can be intercepted and processed. The main hooks are:
  - PREROUTING: Before routing decisions are made.
  - INPUT: For packets destined for the local system.
  - FORWARD: For packets being routed through the system.
  - OUTPUT: For packets generated by the local system.
  - POSTROUTING: After routing decisions are made.
- Tables: Netfilter organizes rules into tables, the most common being:
  - filter: The default table, used for packet filtering.
  - nat: Used for network address translation.
  - mangle: Used for specialized packet alterations.
- Chains: Each table contains chains, which are lists of rules that packets are checked against. Each rule specifies a target action (e.g., ACCEPT, DROP) to take when a packet matches.
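To see how hooks, tables, and chains fit together, here is a sketch of a simple NAT gateway, a common netfilter use case (eth0/eth1 are placeholder interface names; all commands require root):

```shell
# Hypothetical setup: LAN on eth1 shares eth0's uplink
sysctl -w net.ipv4.ip_forward=1                        # let the kernel route packets
iptables -t nat -A POSTROUTING -o eth0 -j MASQUERADE   # nat table, POSTROUTING hook
iptables -A FORWARD -i eth1 -o eth0 -j ACCEPT          # filter table, FORWARD chain
iptables -A FORWARD -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT
```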
Common Commands
- List Rules: To view the current rules in a specific table, use:
  iptables -L
- Add a Rule: To add a new rule to a chain, use:
  iptables -A INPUT -p tcp --dport 80 -j ACCEPT
- Delete a Rule: To delete a specific rule, use:
  iptables -D INPUT -p tcp --dport 80 -j ACCEPT
- Save Rules: To save the current rules to a file, use:
  iptables-save > /etc/iptables/rules.v4
Applications
Netfilter is widely used for:
- Firewalling: Protecting systems from unauthorized access and attacks.
- NAT: Allowing multiple devices on a local network to share a single public IP address.
- Traffic Shaping: Managing and controlling the flow of network traffic.
Conclusion
Netfilter is a crucial component of the Linux networking stack, providing powerful capabilities for packet filtering and manipulation. Understanding how to configure and use netfilter effectively is essential for system administrators and network engineers.
tc
tc (traffic control) is a user-space utility used to configure Traffic Control in the Linux kernel's network stack. It allows administrators to configure queuing disciplines (qdiscs), which determine how packets are enqueued and dequeued on a network interface.
Important Components of tc
- qdisc (Queuing Discipline): The core component of tc, which defines the algorithm used to manage the packet queue. Examples include pfifo_fast, fq_codel, and netem.
- class: A way to create a hierarchy within a qdisc, allowing for more granular control over traffic. Classes can be used to apply different rules to different types of traffic.
- filter: Used to classify packets into different classes. Filters can match on various packet attributes, such as IP address, port number, or protocol.
- action: Defines what to do with packets that match a filter. Actions can include marking, mirroring, or redirecting packets.
Uses of tc
- Traffic Shaping: Control the rate of outgoing traffic to ensure that the network is not overwhelmed. This can be useful for managing bandwidth usage and ensuring fair distribution of network resources.
- Traffic Policing: Enforce limits on the rate of incoming traffic, dropping packets that exceed the specified rate. This can help protect against network abuse or attacks.
- Network Emulation: Simulate various network conditions, such as latency, packet loss, and jitter, to test the performance of applications under different scenarios.
- Quality of Service (QoS): Prioritize certain types of traffic to ensure that critical applications receive the necessary bandwidth and low latency.
By using tc, administrators can fine-tune network performance, improve reliability, and ensure that critical applications have the necessary resources to function optimally.
# Add 100ms of delay to all traffic on eth0
sudo tc qdisc add dev eth0 root netem delay 100ms
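As a sketch of traffic shaping, the snippet below caps egress on an interface using an HTB qdisc with a single class (the interface name and 10 Mbit/s rate are placeholders; requires root):

```shell
tc qdisc add dev eth0 root handle 1: htb default 10            # HTB qdisc at the root
tc class add dev eth0 parent 1: classid 1:10 htb rate 10mbit ceil 10mbit
tc qdisc show dev eth0                                         # verify the configuration
```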
iptables
iptables is a user-space utility program that allows a system administrator to configure the IP packet filter rules of the Linux kernel firewall. It is a powerful tool for managing network traffic and enhancing security.
Key Concepts
- Chains: A chain is an ordered list of rules that iptables checks packets against. The filter table has three built-in chains: INPUT, OUTPUT, and FORWARD.
- Tables: iptables organizes rules into tables, the most common being the filter table, which is used for packet filtering.
- Targets: Each rule in a chain specifies a target, the action to take when a packet matches the rule. Common targets include ACCEPT, DROP, and REJECT.
Common Commands
- List Rules: To view the current rules in a specific chain, use:
  iptables -L
- Add a Rule: To add a new rule to a chain, use:
  iptables -A INPUT -s 192.168.1.1 -j ACCEPT
- Delete a Rule: To delete a specific rule, use:
  iptables -D INPUT -s 192.168.1.1 -j ACCEPT
- Save Rules: To save the current rules to a file, use:
  iptables-save > /etc/iptables/rules.v4
Applications
iptables is widely used for:
- Network Security: Protecting systems from unauthorized access and attacks.
- Traffic Control: Managing and controlling the flow of network traffic.
- Logging: Keeping track of network activity for analysis and troubleshooting.
Conclusion
iptables is an essential tool for network management and security in Linux environments. Understanding how to configure and use iptables effectively is crucial for system administrators and network engineers.
ELI10: What is iptables?
iptables is like a set of rules for your computer's door. Just like you might have rules about who can come into your house or what they can bring, iptables helps your computer decide what kind of data can come in or go out.
Here’s a simple breakdown:
- Chains: Think of these as different doors. Each door has its own set of rules. For example, one door might let in friends (INPUT), another might let out toys (OUTPUT), and a third might let things pass through without stopping (FORWARD).
- Tables: These are like the lists of rules for each door. The most common list is for filtering, which decides what gets to come in or go out.
- Targets: When something tries to come through a door, the rules tell it what to do. It might be allowed in (ACCEPT), told to go away (DROP), or asked to leave a message (REJECT).
So, iptables is a way to keep your computer safe and make sure only the right data gets in and out!
Example Commands
- List Rules: To see what rules are set up, you can use:
  iptables -L
- Add a Rule: If you want to let a specific friend in, you can add a rule like this:
  iptables -A INPUT -s 192.168.1.1 -j ACCEPT
- Delete a Rule: If you want to remove a rule, you can do it like this:
  iptables -D INPUT -s 192.168.1.1 -j ACCEPT
- Save Rules: To keep your rules safe, you can save them to a file:
  iptables-save > /etc/iptables/rules.v4
Why Use iptables?
Using iptables helps keep your computer safe from bad data and makes sure everything runs smoothly. It's like having a good security system for your digital home!
systemd
systemd is a system and service manager for Linux operating systems. It provides aggressive parallelization capabilities, uses socket and D-Bus activation for starting services, offers on-demand starting of daemons, and maintains process tracking using Linux control groups.
Overview
systemd replaces the traditional SysV init system and provides a more modern approach to system initialization and service management.
Key Features:
- Parallel service startup
- Socket and D-Bus activation
- On-demand service starting
- Process supervision
- Mount and automount point management
- System state snapshots
- Logging with journald
Basic Concepts
Units: Resources that systemd manages
- Service units (.service): System services
- Socket units (.socket): IPC or network sockets
- Target units (.target): Group of units (like runlevels)
- Mount units (.mount): Mount points
- Timer units (.timer): Scheduled tasks
- Device units (.device): Device files
- Path units (.path): File/directory monitoring
Service Management
systemctl Commands
# Service control
sudo systemctl start service_name
sudo systemctl stop service_name
sudo systemctl restart service_name
sudo systemctl reload service_name # Reload config without restart
sudo systemctl reload-or-restart service_name
# Enable/disable services (start at boot)
sudo systemctl enable service_name
sudo systemctl disable service_name
sudo systemctl enable --now service_name # Enable and start
# Check service status
systemctl status service_name
systemctl is-active service_name
systemctl is-enabled service_name
systemctl is-failed service_name
# List services
systemctl list-units --type=service
systemctl list-units --type=service --state=running
systemctl list-units --type=service --state=failed
systemctl list-unit-files --type=service
# Show service configuration
systemctl cat service_name
systemctl show service_name
# Service dependencies
systemctl list-dependencies service_name
Service Examples
# Common services
sudo systemctl status nginx
sudo systemctl restart sshd
sudo systemctl enable docker
sudo systemctl start postgresql
# Check all failed services
systemctl --failed
# Mask service (prevent from being started)
sudo systemctl mask service_name
sudo systemctl unmask service_name
Creating Service Units
Basic Service File
# /etc/systemd/system/myapp.service
[Unit]
Description=My Application
After=network.target
Wants=network-online.target
[Service]
Type=simple
User=myapp
Group=myapp
WorkingDirectory=/opt/myapp
ExecStart=/opt/myapp/bin/myapp
Restart=on-failure
RestartSec=5s
[Install]
WantedBy=multi-user.target
Service Types
# Type=simple (default)
[Service]
Type=simple
ExecStart=/usr/bin/myapp
# Type=forking (daemon that forks)
[Service]
Type=forking
PIDFile=/var/run/myapp.pid
ExecStart=/usr/bin/myapp --daemon
# Type=oneshot (runs once and exits)
[Service]
Type=oneshot
ExecStart=/usr/bin/backup-script.sh
RemainAfterExit=yes
# Type=notify (sends notification when ready)
[Service]
Type=notify
ExecStart=/usr/bin/myapp
NotifyAccess=main
# Type=dbus (acquires D-Bus name)
[Service]
Type=dbus
BusName=org.example.myapp
ExecStart=/usr/bin/myapp
# Type=idle (delays until all jobs finished)
[Service]
Type=idle
ExecStart=/usr/bin/myapp
Advanced Service Configuration
# /etc/systemd/system/myapp.service
[Unit]
Description=My Web Application
Documentation=https://example.com/docs
After=network-online.target postgresql.service
Wants=network-online.target
Requires=postgresql.service
[Service]
Type=notify
User=www-data
Group=www-data
WorkingDirectory=/opt/myapp
# Environment
Environment="NODE_ENV=production"
Environment="PORT=3000"
EnvironmentFile=/etc/myapp/config
# Execution
ExecStartPre=/usr/bin/myapp-check-config
ExecStart=/usr/bin/node /opt/myapp/server.js
ExecReload=/bin/kill -HUP $MAINPID
ExecStop=/bin/kill -TERM $MAINPID
# Restart policy
Restart=on-failure
RestartSec=5s
StartLimitInterval=10min
StartLimitBurst=5
# Security
PrivateTmp=true
NoNewPrivileges=true
ProtectSystem=strict
ProtectHome=true
ReadWritePaths=/var/lib/myapp
ReadWritePaths=/var/log/myapp
# Resource limits
LimitNOFILE=65536
MemoryMax=1G
CPUQuota=200%
# Logging
StandardOutput=journal
StandardError=journal
SyslogIdentifier=myapp
[Install]
WantedBy=multi-user.target
Service Management Workflow
# Create service file
sudo vim /etc/systemd/system/myapp.service
# Reload systemd configuration
sudo systemctl daemon-reload
# Enable and start service
sudo systemctl enable --now myapp
# Check status
systemctl status myapp
# View logs
journalctl -u myapp -f
# Edit service (creates override)
sudo systemctl edit myapp
# Edit full service file
sudo systemctl edit --full myapp
Timers (Cron Alternative)
Timer Unit
# /etc/systemd/system/backup.timer
[Unit]
Description=Daily Backup Timer
Requires=backup.service
[Timer]
OnCalendar=*-*-* 02:00:00
Persistent=true
Unit=backup.service
[Install]
WantedBy=timers.target
Corresponding Service
# /etc/systemd/system/backup.service
[Unit]
Description=Backup Service
[Service]
Type=oneshot
ExecStart=/usr/local/bin/backup.sh
User=backup
Timer Management
# Enable and start timer
sudo systemctl enable --now backup.timer
# List timers
systemctl list-timers
systemctl list-timers --all
# Check timer status
systemctl status backup.timer
# View next run time
systemctl list-timers backup.timer
# Manual trigger
sudo systemctl start backup.service
Timer Examples
# Every 5 minutes
OnCalendar=*:0/5
# Every hour
OnCalendar=hourly
# Every day at 3:00 AM
OnCalendar=*-*-* 03:00:00
# Every Monday at 9:00 AM
OnCalendar=Mon *-*-* 09:00:00
# First day of month
OnCalendar=*-*-01 00:00:00
# Relative to boot
OnBootSec=15min
OnUnitActiveSec=1h
journalctl (Logging)
Viewing Logs
# View all logs
journalctl
# Follow logs (like tail -f)
journalctl -f
# Recent logs
journalctl -n 50 # Last 50 lines
journalctl -n 100 --no-pager
# Service-specific logs
journalctl -u nginx
journalctl -u nginx -f
journalctl -u nginx --since today
# Multiple services
journalctl -u nginx -u postgresql
# Time-based filtering
journalctl --since "2024-01-01"
journalctl --since "2024-01-01 10:00" --until "2024-01-01 11:00"
journalctl --since "1 hour ago"
journalctl --since yesterday
journalctl --since "10 min ago"
# Priority filtering
journalctl -p err # Errors only
journalctl -p warning # Warnings and above
journalctl -p 0..3 # Emergency to error
# Kernel messages
journalctl -k
journalctl -k -b # Current boot
# Boot-specific logs
journalctl -b # Current boot
journalctl -b -1 # Previous boot
journalctl --list-boots # List all boots
# Specific process
journalctl _PID=1234
# Output formats
journalctl -o json # JSON format
journalctl -o json-pretty
journalctl -o verbose
journalctl -o cat # Just the message
# Disk usage
journalctl --disk-usage
# Verify integrity
journalctl --verify
Journal Management
# Clean old logs
sudo journalctl --vacuum-time=7d # Keep last 7 days
sudo journalctl --vacuum-size=500M # Keep max 500MB
sudo journalctl --vacuum-files=5 # Keep max 5 files
# Rotate journals
sudo systemctl kill --signal=SIGUSR2 systemd-journald
# Configure retention
# /etc/systemd/journald.conf
[Journal]
SystemMaxUse=500M
SystemMaxFileSize=100M
SystemMaxFiles=5
RuntimeMaxUse=100M
MaxRetentionSec=7day
Targets (Runlevels)
Common Targets
# List targets
systemctl list-units --type=target
# Current target
systemctl get-default
# Change default target
sudo systemctl set-default multi-user.target
sudo systemctl set-default graphical.target
# Switch target
sudo systemctl isolate multi-user.target
sudo systemctl isolate rescue.target
# Common targets
# poweroff.target (runlevel 0)
# rescue.target (runlevel 1)
# multi-user.target (runlevel 3)
# graphical.target (runlevel 5)
# reboot.target (runlevel 6)
System Management
System Control
# Reboot/shutdown
sudo systemctl reboot
sudo systemctl poweroff
sudo systemctl halt
sudo systemctl suspend
sudo systemctl hibernate
sudo systemctl hybrid-sleep
# System state
systemctl is-system-running
# Reload systemd configuration
sudo systemctl daemon-reload
# Reexecute systemd
sudo systemctl daemon-reexec
# Show system boot time
systemd-analyze
systemd-analyze blame # Show service startup times
systemd-analyze critical-chain # Show critical startup chain
systemd-analyze plot > boot.svg # Generate SVG timeline
# List all units
systemctl list-units
systemctl list-units --all
systemctl list-unit-files
# Check configuration
sudo systemd-analyze verify /etc/systemd/system/myapp.service
Socket Activation
# /etc/systemd/system/myapp.socket
[Unit]
Description=My App Socket
[Socket]
ListenStream=8080
Accept=no
[Install]
WantedBy=sockets.target
# /etc/systemd/system/myapp.service
[Unit]
Description=My App Service
Requires=myapp.socket
[Service]
ExecStart=/usr/bin/myapp
StandardInput=socket
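For local testing, systemd ships a helper that emulates socket activation without installing any units. A sketch, assuming a hypothetical binary at /usr/bin/myapp:

```shell
# Launch myapp with an activated listening socket on port 8080
# (the binary path and port are placeholders)
systemd-socket-activate -l 8080 /usr/bin/myapp
```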
Path Units (File Monitoring)
# /etc/systemd/system/watch-config.path
[Unit]
Description=Watch Config Directory
[Path]
PathModified=/etc/myapp
Unit=process-config.service
[Install]
WantedBy=multi-user.target
# /etc/systemd/system/process-config.service
[Unit]
Description=Process Config Changes
[Service]
Type=oneshot
ExecStart=/usr/local/bin/reload-config.sh
User Services
# User service directory
~/.config/systemd/user/
# User commands (no sudo)
systemctl --user start myservice
systemctl --user enable myservice
systemctl --user status myservice
# User timers
systemctl --user list-timers
# Enable lingering (services run without login)
loginctl enable-linger username
# User journal
journalctl --user
journalctl --user -u myservice
Example User Service
# ~/.config/systemd/user/myapp.service
[Unit]
Description=My User Application
[Service]
ExecStart=%h/bin/myapp
Restart=on-failure
[Install]
WantedBy=default.target
Security Features
Service Hardening
[Service]
# User/Group isolation
User=myapp
Group=myapp
DynamicUser=yes # Create temporary user
# Filesystem restrictions
ProtectSystem=strict # Read-only /usr, /boot, /efi
ProtectHome=true # Inaccessible /home
PrivateTmp=true # Private /tmp
ReadWritePaths=/var/lib/myapp # Writable paths
ReadOnlyPaths=/etc/myapp
InaccessiblePaths=/root
# Namespace isolation
PrivateDevices=yes # Private /dev
PrivateNetwork=yes # Private network namespace
PrivateUsers=yes # User namespace
# Capabilities
NoNewPrivileges=yes # Prevent privilege escalation
CapabilityBoundingSet=CAP_NET_BIND_SERVICE
AmbientCapabilities=CAP_NET_BIND_SERVICE
# System calls
SystemCallFilter=@system-service
SystemCallFilter=~@privileged @resources
SystemCallErrorNumber=EPERM
# Misc restrictions
RestrictAddressFamilies=AF_INET AF_INET6
RestrictNamespaces=yes
RestrictRealtime=yes
LockPersonality=yes
ProtectKernelTunables=yes
ProtectKernelModules=yes
ProtectControlGroups=yes
MemoryDenyWriteExecute=yes
Troubleshooting
Common Issues
# Service won't start
systemctl status service_name
journalctl -u service_name -n 50
journalctl -xe
# Check service configuration
systemd-analyze verify /etc/systemd/system/myapp.service
# Dependency issues
systemctl list-dependencies service_name
systemctl list-dependencies --reverse service_name
# Stuck service
sudo systemctl kill service_name
sudo systemctl kill -s SIGKILL service_name
# Reset failed state
sudo systemctl reset-failed service_name
sudo systemctl reset-failed
# Show why service failed
systemctl status service_name --no-pager --full
# Debug mode
sudo SYSTEMD_LOG_LEVEL=debug systemctl start service_name
# Emergency shell
# Add to kernel command line: systemd.unit=emergency.target
Debugging Services
# Add debug output to service
[Service]
Environment="DEBUG=true"
StandardOutput=journal+console
StandardError=journal+console
# Increase log level
LogLevel=debug
# Show environment
systemctl show-environment
systemctl show service_name
Best Practices
# 1. Use After= and Wants= for dependencies
[Unit]
After=network-online.target
Wants=network-online.target
# 2. Set restart policy
[Service]
Restart=on-failure
RestartSec=5s
StartLimitInterval=10min
StartLimitBurst=5
# 3. Use specific user
[Service]
User=myapp
Group=myapp
# 4. Set working directory
[Service]
WorkingDirectory=/opt/myapp
# 5. Use environment files
[Service]
EnvironmentFile=/etc/myapp/config
# 6. Add security restrictions
[Service]
ProtectSystem=strict
PrivateTmp=true
NoNewPrivileges=true
# 7. Proper logging
[Service]
StandardOutput=journal
StandardError=journal
SyslogIdentifier=myapp
# 8. Resource limits
[Service]
LimitNOFILE=65536
MemoryMax=1G
# 9. Use timers instead of cron
# Create .timer and .service files
# 10. Test configuration
sudo systemd-analyze verify myapp.service
Quick Reference
Service Management
| Command | Description |
|---|---|
systemctl start | Start service |
systemctl stop | Stop service |
systemctl restart | Restart service |
systemctl reload | Reload configuration |
systemctl enable | Enable at boot |
systemctl disable | Disable at boot |
systemctl status | Show service status |
systemctl is-active | Check if active |
systemctl is-enabled | Check if enabled |
Journalctl
| Command | Description |
|---|---|
journalctl -u SERVICE | Service logs |
journalctl -f | Follow logs |
journalctl -b | Current boot logs |
journalctl --since | Time-filtered logs |
journalctl -p err | Error priority logs |
journalctl -k | Kernel messages |
systemd provides a powerful, modern init system with extensive features for service management, logging, and system administration, making it the standard for most Linux distributions.
sysctl
sysctl is a tool for examining and changing kernel parameters at runtime. It's used to modify kernel behavior without rebooting.
Basic Usage
# List all parameters
sysctl -a
# Get specific parameter
sysctl net.ipv4.ip_forward
# Set parameter (temporary)
sudo sysctl -w net.ipv4.ip_forward=1
# Load from configuration file
sudo sysctl -p /etc/sysctl.conf
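Under the hood, each sysctl key maps to a file under /proc/sys, with dots becoming slashes, so values can also be read and written through the filesystem:

```shell
# kernel.ostype  <->  /proc/sys/kernel/ostype
cat /proc/sys/kernel/ostype        # prints: Linux
# sysctl -n kernel.ostype          # same value via the sysctl tool
# Writing works the same way (requires root):
#   echo 1 > /proc/sys/net/ipv4/ip_forward   # == sysctl -w net.ipv4.ip_forward=1
```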
Common Parameters
# Network settings
net.ipv4.ip_forward = 1 # Enable IP forwarding
net.ipv4.tcp_syncookies = 1 # SYN flood protection
net.core.somaxconn = 1024 # Connection backlog
net.ipv4.tcp_max_syn_backlog = 2048 # SYN backlog
# Memory settings
vm.swappiness = 10 # Swap preference (0-100)
vm.dirty_ratio = 15 # Dirty page threshold
vm.overcommit_memory = 1 # Memory overcommit
# File system
fs.file-max = 65536 # Max open files
fs.inotify.max_user_watches = 524288 # inotify watches
# Kernel settings
kernel.sysrq = 1 # Enable SysRq key
kernel.panic = 10 # Reboot after panic (seconds)
Persistent Configuration
# /etc/sysctl.conf or /etc/sysctl.d/99-custom.conf
net.ipv4.ip_forward = 1
vm.swappiness = 10
fs.file-max = 100000
# Apply configuration
sudo sysctl -p
Performance Tuning
# High-performance networking
net.core.rmem_max = 134217728
net.core.wmem_max = 134217728
net.ipv4.tcp_rmem = 4096 87380 67108864
net.ipv4.tcp_wmem = 4096 65536 67108864
net.ipv4.tcp_congestion_control = bbr
# Database server optimization
vm.swappiness = 1
vm.dirty_background_ratio = 5
vm.dirty_ratio = 10
sysctl provides runtime kernel tuning for optimizing system performance and behavior.
sysfs
sysfs is a virtual filesystem that exports information about kernel subsystems, hardware devices, and associated device drivers to userspace.
Overview
sysfs is mounted at /sys and provides:
- Device information
- Driver parameters
- Kernel configuration
- Power management settings
Structure
/sys/
├── block/ # Block devices
├── bus/ # Bus types (pci, usb, etc.)
├── class/ # Device classes (network, input, etc.)
├── devices/ # Device tree
├── firmware/ # Firmware information
├── fs/ # Filesystem information
├── kernel/ # Kernel parameters
├── module/ # Loaded kernel modules
└── power/ # Power management
Common Usage
# List block devices
ls /sys/block/
# Device information
cat /sys/class/net/eth0/address # MAC address
cat /sys/class/net/eth0/speed # Link speed
cat /sys/class/net/eth0/operstate # Interface state
# CPU information
cat /sys/devices/system/cpu/cpu0/cpufreq/scaling_governor
ls /sys/devices/system/cpu/cpu*/topology/
# GPU information
cat /sys/class/drm/card0/device/vendor
cat /sys/class/drm/card0/device/device
# USB devices
ls /sys/bus/usb/devices/
# Module parameters
ls /sys/module/*/parameters/
cat /sys/module/bluetooth/parameters/disable_esco
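Because every sysfs entry is an ordinary file, scripting over a whole device class is straightforward; a small read-only sketch that lists each network interface and its state (assumes a Linux host with at least the loopback interface):

```shell
#!/usr/bin/env bash
# Enumerate /sys/class/net and report each interface's operstate.
report=""
for dev in /sys/class/net/*; do
  name="${dev##*/}"
  state="$(cat "$dev/operstate" 2>/dev/null || echo unknown)"
  report="${report}${name}: ${state}
"
done
printf '%s' "$report"
```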
Power Management
# CPU frequency scaling
echo "performance" | sudo tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor
# Device power state
cat /sys/class/net/eth0/device/power/runtime_status
# Display brightness
echo 50 | sudo tee /sys/class/backlight/*/brightness
LED Control
# List LEDs
ls /sys/class/leds/
# Control LED (root required)
echo 1 | sudo tee /sys/class/leds/led0/brightness
echo 0 | sudo tee /sys/class/leds/led0/brightness
# LED trigger
echo "heartbeat" | sudo tee /sys/class/leds/led0/trigger
sysfs provides a unified interface for interacting with kernel and hardware information.
Android Development
Overview
Android is an open-source operating system based on the Linux kernel, designed primarily for mobile devices. It's the world's most popular mobile platform, powering billions of devices worldwide. Android development involves creating applications using Java, Kotlin, or C++ that run on Android devices.
Key Concepts
Android Platform Architecture
Android is built on a multi-layered architecture:
- Linux Kernel: Foundation providing core system services
- Hardware Abstraction Layer (HAL): Interface between hardware and software
- Android Runtime (ART): Executes app bytecode with optimized performance
- Native C/C++ Libraries: Core system libraries (SQLite, OpenGL, etc.)
- Java API Framework: High-level APIs for app development
- System Apps: Pre-installed applications
Core Components
Android applications are built using four fundamental components:
- Activities: Single screen with a user interface
- Services: Background operations without UI
- Broadcast Receivers: Respond to system-wide broadcast announcements
- Content Providers: Manage shared app data
Development Environment
Prerequisites
- Java Development Kit (JDK): Version 17 or higher for current Android Gradle Plugin releases (older toolchains accept JDK 8+)
- Android Studio: Official IDE for Android development
- Android SDK: Software development kit with tools and APIs
- Gradle: Build automation system
Installation
# Download Android Studio from https://developer.android.com/studio
# Install Android Studio and SDK through the setup wizard
# Verify installation
adb --version
Quick Start
Creating Your First App
// MainActivity.kt
package com.example.myfirstapp
import android.os.Bundle
import androidx.appcompat.app.AppCompatActivity
import android.widget.TextView
class MainActivity : AppCompatActivity() {
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
setContentView(R.layout.activity_main)
val textView: TextView = findViewById(R.id.textView)
textView.text = "Hello, Android!"
}
}
<!-- res/layout/activity_main.xml -->
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout
xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:orientation="vertical"
android:gravity="center">
<TextView
android:id="@+id/textView"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="Hello World!"
android:textSize="24sp" />
</LinearLayout>
AndroidManifest.xml
Every Android app must have a manifest file:
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.example.myfirstapp">
<application
android:allowBackup="true"
android:icon="@mipmap/ic_launcher"
android:label="@string/app_name"
android:theme="@style/Theme.AppCompat">
<activity android:name=".MainActivity" android:exported="true">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
</application>
<!-- Permissions -->
<uses-permission android:name="android.permission.INTERNET" />
</manifest>
Android Application Components
Activities Lifecycle
class MyActivity : AppCompatActivity() {
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
// Initialize activity
setContentView(R.layout.activity_my)
}
override fun onStart() {
super.onStart()
// Activity is becoming visible
}
override fun onResume() {
super.onResume()
// Activity is interactive
}
override fun onPause() {
super.onPause()
// Activity is losing focus
}
override fun onStop() {
super.onStop()
// Activity is no longer visible
}
override fun onDestroy() {
super.onDestroy()
// Activity is being destroyed
}
}
Intents
Intents are messaging objects used to request actions from other components:
// Explicit Intent - Start specific activity
val intent = Intent(this, SecondActivity::class.java)
intent.putExtra("KEY_NAME", "value")
startActivity(intent)
// Implicit Intent - Let system find appropriate component
val browserIntent = Intent(Intent.ACTION_VIEW, Uri.parse("https://www.example.com"))
startActivity(browserIntent)
// Share content
val shareIntent = Intent().apply {
action = Intent.ACTION_SEND
putExtra(Intent.EXTRA_TEXT, "Check out this content!")
type = "text/plain"
}
startActivity(Intent.createChooser(shareIntent, "Share via"))
UI Development
Views and ViewGroups
// Programmatically create UI
val layout = LinearLayout(this).apply {
orientation = LinearLayout.VERTICAL
layoutParams = LinearLayout.LayoutParams(
LinearLayout.LayoutParams.MATCH_PARENT,
LinearLayout.LayoutParams.MATCH_PARENT
)
}
val button = Button(this).apply {
text = "Click Me"
setOnClickListener {
Toast.makeText(context, "Button clicked!", Toast.LENGTH_SHORT).show()
}
}
layout.addView(button)
setContentView(layout)
RecyclerView Example
// Adapter
class MyAdapter(private val items: List<String>) :
RecyclerView.Adapter<MyAdapter.ViewHolder>() {
class ViewHolder(view: View) : RecyclerView.ViewHolder(view) {
val textView: TextView = view.findViewById(R.id.textView)
}
override fun onCreateViewHolder(parent: ViewGroup, viewType: Int): ViewHolder {
val view = LayoutInflater.from(parent.context)
.inflate(R.layout.item_layout, parent, false)
return ViewHolder(view)
}
override fun onBindViewHolder(holder: ViewHolder, position: Int) {
holder.textView.text = items[position]
}
override fun getItemCount() = items.size
}
// Usage in Activity
val recyclerView: RecyclerView = findViewById(R.id.recyclerView)
recyclerView.layoutManager = LinearLayoutManager(this)
recyclerView.adapter = MyAdapter(listOf("Item 1", "Item 2", "Item 3"))
Data Storage
SharedPreferences
// Save data
val sharedPref = getSharedPreferences("MyPrefs", Context.MODE_PRIVATE)
with(sharedPref.edit()) {
putString("username", "john_doe")
putInt("score", 100)
apply()
}
// Read data
val username = sharedPref.getString("username", "default")
val score = sharedPref.getInt("score", 0)
Room Database
// Entity
@Entity(tableName = "users")
data class User(
@PrimaryKey(autoGenerate = true) val id: Int = 0,
@ColumnInfo(name = "name") val name: String,
@ColumnInfo(name = "email") val email: String
)
// DAO
@Dao
interface UserDao {
@Query("SELECT * FROM users")
fun getAllUsers(): List<User>
@Insert
fun insert(user: User)
@Delete
fun delete(user: User)
}
// Database
@Database(entities = [User::class], version = 1)
abstract class AppDatabase : RoomDatabase() {
abstract fun userDao(): UserDao
}
Networking
Retrofit Example
// API Interface
interface ApiService {
@GET("users/{id}")
suspend fun getUser(@Path("id") userId: Int): User
@POST("users")
suspend fun createUser(@Body user: User): User
}
// Implementation
val retrofit = Retrofit.Builder()
.baseUrl("https://api.example.com/")
.addConverterFactory(GsonConverterFactory.create())
.build()
val apiService = retrofit.create(ApiService::class.java)
// Usage with Coroutines
lifecycleScope.launch {
try {
val user = apiService.getUser(1)
// Update UI with user data
} catch (e: Exception) {
// Handle error
}
}
Modern Android Development
Jetpack Compose
Jetpack Compose is Android's modern toolkit for building native UI:
@Composable
fun Greeting(name: String) {
Text(
text = "Hello $name!",
modifier = Modifier.padding(16.dp),
style = MaterialTheme.typography.h4
)
}
@Composable
fun Counter() {
var count by remember { mutableStateOf(0) }
Column(
modifier = Modifier.fillMaxSize(),
horizontalAlignment = Alignment.CenterHorizontally,
verticalArrangement = Arrangement.Center
) {
Text("Count: $count")
Button(onClick = { count++ }) {
Text("Increment")
}
}
}
ViewModel
class MyViewModel : ViewModel() {
private val _uiState = MutableLiveData<UiState>()
val uiState: LiveData<UiState> = _uiState
fun loadData() {
viewModelScope.launch {
try {
val data = repository.getData()
_uiState.value = UiState.Success(data)
} catch (e: Exception) {
_uiState.value = UiState.Error(e.message)
}
}
}
}
// Usage in Activity
class MainActivity : AppCompatActivity() {
private val viewModel: MyViewModel by viewModels()
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
viewModel.uiState.observe(this) { state ->
when (state) {
is UiState.Success -> updateUI(state.data)
is UiState.Error -> showError(state.message)
}
}
}
}
Testing
Unit Tests
class CalculatorTest {
@Test
fun addition_isCorrect() {
assertEquals(4, 2 + 2)
}
@Test
fun viewModel_loadsData() = runTest {
val viewModel = MyViewModel(FakeRepository())
viewModel.loadData()
val state = viewModel.uiState.value
assertTrue(state is UiState.Success)
}
}
Instrumented Tests
@RunWith(AndroidJUnit4::class)
class MainActivityTest {
@get:Rule
val activityRule = ActivityScenarioRule(MainActivity::class.java)
@Test
fun testButtonClick() {
onView(withId(R.id.button))
.perform(click())
onView(withId(R.id.textView))
.check(matches(withText("Button clicked!")))
}
}
Build Configuration
build.gradle (Module level)
plugins {
id 'com.android.application'
id 'org.jetbrains.kotlin.android'
}
android {
namespace 'com.example.myapp'
compileSdk 34
defaultConfig {
applicationId "com.example.myapp"
minSdk 24
targetSdk 34
versionCode 1
versionName "1.0"
}
buildTypes {
release {
minifyEnabled true
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
}
}
compileOptions {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}
}
dependencies {
implementation 'androidx.core:core-ktx:1.12.0'
implementation 'androidx.appcompat:appcompat:1.6.1'
implementation 'com.google.android.material:material:1.11.0'
testImplementation 'junit:junit:4.13.2'
androidTestImplementation 'androidx.test.ext:junit:1.1.5'
}
Best Practices
- Follow Material Design Guidelines: Use Material Components for consistent UI
- Handle Configuration Changes: Save state during rotation
- Use Architecture Components: ViewModel, LiveData, Room
- Implement Proper Error Handling: Never ignore exceptions
- Optimize Performance: Avoid blocking the main thread
- Test Your Code: Write unit and instrumented tests
- Follow Android Security Best Practices: Validate inputs, use encryption
- Support Multiple Screen Sizes: Use responsive layouts
- Handle Permissions Properly: Request permissions at runtime
- Keep Libraries Updated: Use latest stable versions
Resources
Documentation
Related Files
- Android Internals - Understanding Android architecture
- ADB Commands - Android Debug Bridge reference
- Development Guide - Detailed development workflow
- Android Binder - Inter-process communication mechanism
Common Issues
Gradle Sync Failed
# Clear Gradle cache
./gradlew clean
# Invalidate caches in Android Studio: File > Invalidate Caches / Restart
App Crashes on Launch
- Check Logcat for stack traces
- Verify all required permissions are declared
- Ensure ProGuard rules are correct for release builds
Memory Leaks
- Use LeakCanary for detection
- Avoid holding Activity context in long-lived objects
- Unregister listeners and callbacks
Next Steps
- Complete the Android Development Guide
- Learn ADB commands for debugging
- Study Android Internals for deeper understanding
- Build sample projects to practice
- Explore Jetpack Compose for modern UI development
Android Internals
Overview
Android is an open-source operating system primarily designed for mobile devices such as smartphones and tablets. It is based on the Linux kernel and developed by Google. Understanding Android internals is crucial for developers who want to create efficient and optimized applications or modify the operating system itself.
Key Components
1. Linux Kernel
The Linux kernel is the core of the Android operating system. It provides essential system services such as process management, memory management, security, and hardware abstraction. The kernel also includes drivers for various hardware components like display, camera, and audio.
2. Hardware Abstraction Layer (HAL)
The Hardware Abstraction Layer (HAL) defines a standard interface for hardware vendors to implement. It allows Android to communicate with the hardware-specific drivers in the Linux kernel. HAL modules are implemented as shared libraries and loaded by the Android system at runtime.
3. Android Runtime (ART)
The Android Runtime (ART) is the managed runtime used by applications and some system services on Android. ART executes Dalvik Executable (DEX) bytecode, which is compiled from Java or Kotlin source code. ART includes features like ahead-of-time (AOT) compilation, just-in-time (JIT) compilation, and garbage collection to improve performance and memory management.
4. Native C/C++ Libraries
Android provides a set of native libraries written in C and C++ that are used by various components of the system. These libraries include:
- Bionic: The standard C library (libc) for Android, derived from BSD's libc.
- SurfaceFlinger: The system compositor that combines surfaces from apps and the system into the final display output.
- Media Framework: Provides support for playing and recording audio and video.
- SQLite: A lightweight relational database engine used for data storage.
5. Application Framework
The Application Framework provides a set of higher-level services and APIs that developers use to build applications. Key components of the application framework include:
- Activity Manager: Manages the lifecycle of applications and activities.
- Content Providers: Manage access to structured data and provide a way to share data between applications.
- Resource Manager: Handles resources like strings, graphics, and layout files.
- Notification Manager: Allows applications to display notifications to the user.
- View System: Provides a set of UI components for building user interfaces.
6. System Applications
Android includes a set of core system applications that provide basic functionality to the user. These applications are written using the same APIs available to third-party developers. Examples of system applications include:
- Phone: Manages phone calls and contacts.
- Messages: Handles SMS and MMS messaging.
- Browser: Provides web browsing capabilities.
- Settings: Allows users to configure system settings.
Conclusion
Understanding Android internals is essential for developers who want to create high-performance applications or contribute to the Android open-source project. By familiarizing yourself with the key components of the Android operating system, you can gain a deeper insight into how Android works and how to optimize your applications for better performance and user experience.
Android Development Guide
Overview
This guide covers the complete Android development workflow, from setting up Android Studio to building, debugging, and deploying applications. It focuses on practical development practices and tools that every Android developer should know.
Android Studio Setup
Installation
Android Studio is the official IDE for Android development, built on JetBrains' IntelliJ IDEA.
# Download from https://developer.android.com/studio
# Linux installation
sudo tar -xzf android-studio-*.tar.gz -C /opt/
cd /opt/android-studio/bin
./studio.sh
# Add to PATH (optional)
echo 'export PATH=$PATH:/opt/android-studio/bin' >> ~/.bashrc
Initial Configuration
- Welcome Screen: Choose "Standard" installation
- SDK Components: Install latest Android SDK and tools
- Emulator: Install Android Emulator and system images
- Gradle: Let Android Studio manage Gradle installation
SDK Manager
# Open SDK Manager: Tools > SDK Manager
# Essential SDK Packages:
# - Android SDK Platform (latest)
# - Android SDK Build-Tools
# - Android Emulator
# - Android SDK Platform-Tools
# - Android SDK Tools
# Command-line SDK management
sdkmanager --list
sdkmanager "platform-tools" "platforms;android-34"
sdkmanager --update
AVD Manager
Create virtual devices for testing:
# Open AVD Manager: Tools > AVD Manager
# Or use command line
avdmanager create avd -n Pixel_7 -k "system-images;android-34;google_apis;x86_64"
avdmanager list avd
# Start emulator from command line
emulator -avd Pixel_7
Project Structure
Standard Android Project
MyApp/
├── app/
│ ├── src/
│ │ ├── main/
│ │ │ ├── java/com/example/myapp/
│ │ │ │ ├── MainActivity.kt
│ │ │ │ ├── models/
│ │ │ │ ├── viewmodels/
│ │ │ │ └── repositories/
│ │ │ ├── res/
│ │ │ │ ├── layout/
│ │ │ │ ├── values/
│ │ │ │ ├── drawable/
│ │ │ │ ├── mipmap/
│ │ │ │ └── menu/
│ │ │ └── AndroidManifest.xml
│ │ ├── test/ # Unit tests
│ │ └── androidTest/ # Instrumented tests
│ ├── build.gradle
│ └── proguard-rules.pro
├── gradle/
├── build.gradle
├── settings.gradle
└── gradle.properties
Key Directories
- java/: Source code (Kotlin/Java)
- res/: Resources (layouts, strings, images)
- res/layout/: XML layout files
- res/values/: Strings, colors, dimensions, styles
- res/drawable/: Images and vector graphics
- res/mipmap/: App launcher icons
- AndroidManifest.xml: App configuration and permissions
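A quick sanity check that a checkout follows this layout can be scripted; a minimal sketch (paths taken from the tree above, run from the project root):

```shell
#!/usr/bin/env bash
# Report which of the standard project files are present in the current directory.
for p in settings.gradle build.gradle app/build.gradle \
         app/src/main/AndroidManifest.xml app/src/main/res; do
  if [ -e "$p" ]; then echo "ok       $p"; else echo "missing  $p"; fi
done
```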
Activities
Creating an Activity
Activities represent a single screen in your app.
// MainActivity.kt
package com.example.myapp
import android.os.Bundle
import android.widget.Button
import android.widget.TextView
import androidx.appcompat.app.AppCompatActivity
class MainActivity : AppCompatActivity() {
private lateinit var textView: TextView
private lateinit var button: Button
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
setContentView(R.layout.activity_main)
// Initialize views
textView = findViewById(R.id.textView)
button = findViewById(R.id.button)
// Set click listener
button.setOnClickListener {
textView.text = "Button clicked!"
}
}
}
Activity Lifecycle
class MyActivity : AppCompatActivity() {
private val TAG = "MyActivity"
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
Log.d(TAG, "onCreate called")
setContentView(R.layout.activity_my)
// Restore saved state
savedInstanceState?.let {
val savedText = it.getString("saved_text")
textView.text = savedText
}
}
override fun onStart() {
super.onStart()
Log.d(TAG, "onStart called")
// Activity becoming visible
}
override fun onResume() {
super.onResume()
Log.d(TAG, "onResume called")
// Activity in foreground, user can interact
// Start animations, resume sensors
}
override fun onPause() {
super.onPause()
Log.d(TAG, "onPause called")
// Activity losing focus
// Pause animations, release sensors
}
override fun onStop() {
super.onStop()
Log.d(TAG, "onStop called")
// Activity no longer visible
// Release heavy resources
}
override fun onDestroy() {
super.onDestroy()
Log.d(TAG, "onDestroy called")
// Activity being destroyed
// Final cleanup
}
override fun onSaveInstanceState(outState: Bundle) {
super.onSaveInstanceState(outState)
// Save state before activity is killed
outState.putString("saved_text", textView.text.toString())
}
}
Registering Activities
<!-- AndroidManifest.xml -->
<application>
<!-- Launcher Activity -->
<activity
android:name=".MainActivity"
android:exported="true">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
<!-- Other Activities -->
<activity
android:name=".SecondActivity"
android:label="@string/second_activity_title"
android:parentActivityName=".MainActivity" />
</application>
Intents
Explicit Intents
Used to start specific components within your app:
// Start another activity
val intent = Intent(this, SecondActivity::class.java)
startActivity(intent)
// Pass data to activity
val intent = Intent(this, DetailActivity::class.java).apply {
putExtra("USER_ID", 12345)
putExtra("USERNAME", "john_doe")
putExtra("USER_DATA", userData) // Parcelable or Serializable
}
startActivity(intent)
// Receive data in target activity
class DetailActivity : AppCompatActivity() {
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
val userId = intent.getIntExtra("USER_ID", -1)
val username = intent.getStringExtra("USERNAME")
val userData = intent.getParcelableExtra<UserData>("USER_DATA")
}
}
Activity Results
// Modern approach using Activity Result API
class MainActivity : AppCompatActivity() {
private val getContent = registerForActivityResult(
ActivityResultContracts.StartActivityForResult()
) { result ->
if (result.resultCode == Activity.RESULT_OK) {
val data = result.data?.getStringExtra("RESULT_DATA")
// Handle result
}
}
private fun launchSecondActivity() {
val intent = Intent(this, SecondActivity::class.java)
getContent.launch(intent)
}
}
// Return result from activity
class SecondActivity : AppCompatActivity() {
private fun returnResult() {
val resultIntent = Intent().apply {
putExtra("RESULT_DATA", "Some result")
}
setResult(Activity.RESULT_OK, resultIntent)
finish()
}
}
Implicit Intents
Used to request actions from other apps:
// Open web page
val webpage = Uri.parse("https://www.example.com")
val intent = Intent(Intent.ACTION_VIEW, webpage)
startActivity(intent)
// Make phone call
val phoneNumber = Uri.parse("tel:1234567890")
val intent = Intent(Intent.ACTION_DIAL, phoneNumber)
startActivity(intent)
// Send email
val intent = Intent(Intent.ACTION_SENDTO).apply {
data = Uri.parse("mailto:")
putExtra(Intent.EXTRA_EMAIL, arrayOf("recipient@example.com"))
putExtra(Intent.EXTRA_SUBJECT, "Email subject")
putExtra(Intent.EXTRA_TEXT, "Email body")
}
startActivity(intent)
// Share content
val shareIntent = Intent().apply {
action = Intent.ACTION_SEND
putExtra(Intent.EXTRA_TEXT, "Check this out!")
type = "text/plain"
}
startActivity(Intent.createChooser(shareIntent, "Share via"))
// Take photo
val takePictureIntent = Intent(MediaStore.ACTION_IMAGE_CAPTURE)
if (takePictureIntent.resolveActivity(packageManager) != null) {
startActivity(takePictureIntent)
}
// Pick image from gallery
val pickPhotoIntent = Intent(Intent.ACTION_PICK,
MediaStore.Images.Media.EXTERNAL_CONTENT_URI)
startActivity(pickPhotoIntent)
Intent Filters
<!-- Declare activity can handle specific actions -->
<activity android:name=".ShareActivity">
<intent-filter>
<action android:name="android.intent.action.SEND" />
<category android:name="android.intent.category.DEFAULT" />
<data android:mimeType="text/plain" />
</intent-filter>
</activity>
<!-- Handle custom URL scheme -->
<activity android:name=".DeepLinkActivity">
<intent-filter android:autoVerify="true">
<action android:name="android.intent.action.VIEW" />
<category android:name="android.intent.category.DEFAULT" />
<category android:name="android.intent.category.BROWSABLE" />
<data
android:scheme="https"
android:host="www.example.com"
android:pathPrefix="/app" />
</intent-filter>
</activity>
Layouts
XML Layouts
LinearLayout
<!-- res/layout/activity_main.xml -->
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout
xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:orientation="vertical"
android:padding="16dp">
<TextView
android:id="@+id/titleTextView"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:text="@string/title"
android:textSize="24sp"
android:textStyle="bold" />
<EditText
android:id="@+id/nameEditText"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_marginTop="16dp"
android:hint="@string/enter_name"
android:inputType="textPersonName" />
<Button
android:id="@+id/submitButton"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_gravity="center"
android:layout_marginTop="16dp"
android:text="@string/submit" />
</LinearLayout>
ConstraintLayout
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout
xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
android:layout_width="match_parent"
android:layout_height="match_parent">
<TextView
android:id="@+id/titleTextView"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="@string/title"
android:textSize="24sp"
app:layout_constraintTop_toTopOf="parent"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintEnd_toEndOf="parent"
android:layout_marginTop="32dp" />
<ImageView
android:id="@+id/imageView"
android:layout_width="200dp"
android:layout_height="200dp"
android:src="@drawable/placeholder"
app:layout_constraintTop_toBottomOf="@id/titleTextView"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintEnd_toEndOf="parent"
android:layout_marginTop="24dp" />
<Button
android:id="@+id/button"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="@string/action"
app:layout_constraintBottom_toBottomOf="parent"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintEnd_toEndOf="parent"
android:layout_marginBottom="32dp" />
</androidx.constraintlayout.widget.ConstraintLayout>
RelativeLayout
<?xml version="1.0" encoding="utf-8"?>
<RelativeLayout
xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:padding="16dp">
<TextView
android:id="@+id/header"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_alignParentTop="true"
android:text="@string/header"
android:textSize="20sp" />
<Button
android:id="@+id/button"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_centerInParent="true"
android:text="@string/click_me" />
<TextView
android:id="@+id/footer"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_alignParentBottom="true"
android:text="@string/footer"
android:gravity="center" />
</RelativeLayout>
FrameLayout
<?xml version="1.0" encoding="utf-8"?>
<FrameLayout
xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
android:layout_height="match_parent">
<!-- Background -->
<ImageView
android:layout_width="match_parent"
android:layout_height="match_parent"
android:src="@drawable/background"
android:scaleType="centerCrop" />
<!-- Foreground content -->
<LinearLayout
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_gravity="center"
android:orientation="vertical"
android:padding="32dp">
<TextView
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="@string/welcome"
android:textSize="32sp"
android:textColor="@android:color/white" />
</LinearLayout>
</FrameLayout>
View Binding
Safer alternative to findViewById:
// Enable in build.gradle
android {
buildFeatures {
viewBinding = true
}
}
// Usage in Activity
class MainActivity : AppCompatActivity() {
private lateinit var binding: ActivityMainBinding
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
binding = ActivityMainBinding.inflate(layoutInflater)
setContentView(binding.root)
// Access views directly
binding.button.setOnClickListener {
binding.textView.text = "Clicked!"
}
}
}
// Usage in Fragment
class MyFragment : Fragment() {
private var _binding: FragmentMyBinding? = null
private val binding get() = _binding!!
override fun onCreateView(
inflater: LayoutInflater,
container: ViewGroup?,
savedInstanceState: Bundle?
): View {
_binding = FragmentMyBinding.inflate(inflater, container, false)
return binding.root
}
override fun onDestroyView() {
super.onDestroyView()
_binding = null
}
}
RecyclerView
// Item layout: res/layout/item_user.xml
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout
xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:padding="16dp"
android:orientation="horizontal">
<TextView
android:id="@+id/nameTextView"
android:layout_width="0dp"
android:layout_height="wrap_content"
android:layout_weight="1"
android:textSize="18sp" />
</LinearLayout>
// Data class
data class User(val id: Int, val name: String)
// Adapter
class UserAdapter(private val users: List<User>) :
RecyclerView.Adapter<UserAdapter.UserViewHolder>() {
class UserViewHolder(view: View) : RecyclerView.ViewHolder(view) {
val nameTextView: TextView = view.findViewById(R.id.nameTextView)
}
override fun onCreateViewHolder(parent: ViewGroup, viewType: Int): UserViewHolder {
val view = LayoutInflater.from(parent.context)
.inflate(R.layout.item_user, parent, false)
return UserViewHolder(view)
}
override fun onBindViewHolder(holder: UserViewHolder, position: Int) {
val user = users[position]
holder.nameTextView.text = user.name
holder.itemView.setOnClickListener {
// Handle click
}
}
override fun getItemCount() = users.size
}
// Usage
val recyclerView: RecyclerView = findViewById(R.id.recyclerView)
recyclerView.layoutManager = LinearLayoutManager(this)
recyclerView.adapter = UserAdapter(userList)
Fragments
// Fragment class
class MyFragment : Fragment() {
private var _binding: FragmentMyBinding? = null
private val binding get() = _binding!!
override fun onCreateView(
inflater: LayoutInflater,
container: ViewGroup?,
savedInstanceState: Bundle?
): View {
_binding = FragmentMyBinding.inflate(inflater, container, false)
return binding.root
}
override fun onViewCreated(view: View, savedInstanceState: Bundle?) {
super.onViewCreated(view, savedInstanceState)
binding.button.setOnClickListener {
// Handle click
}
}
override fun onDestroyView() {
super.onDestroyView()
_binding = null
}
}
// Add fragment to activity
supportFragmentManager.commit {
setReorderingAllowed(true)
add(R.id.fragment_container, MyFragment())
}
// Replace fragment
supportFragmentManager.commit {
setReorderingAllowed(true)
replace(R.id.fragment_container, AnotherFragment())
addToBackStack("transaction_name")
}
// Pass arguments to fragment
val fragment = MyFragment().apply {
arguments = Bundle().apply {
putString("ARG_NAME", "value")
putInt("ARG_ID", 123)
}
}
// Retrieve arguments in fragment
val name = arguments?.getString("ARG_NAME")
val id = arguments?.getInt("ARG_ID")
Resources
Strings
<!-- res/values/strings.xml -->
<resources>
<string name="app_name">My App</string>
<string name="welcome_message">Welcome, %1$s!</string>
<string name="items_count">You have %d items</string>
<plurals name="number_of_items">
<item quantity="one">%d item</item>
<item quantity="other">%d items</item>
</plurals>
</resources>
<!-- Usage in code -->
val welcome = getString(R.string.welcome_message, "John")
val countText = getString(R.string.items_count, 5)
val itemCount = 5
val plural = resources.getQuantityString(R.plurals.number_of_items, itemCount, itemCount)
Colors
<!-- res/values/colors.xml -->
<resources>
<color name="purple_200">#FFBB86FC</color>
<color name="purple_500">#FF6200EE</color>
<color name="purple_700">#FF3700B3</color>
<color name="teal_200">#FF03DAC5</color>
<color name="black">#FF000000</color>
<color name="white">#FFFFFFFF</color>
</resources>
Dimensions
<!-- res/values/dimens.xml -->
<resources>
<dimen name="padding_small">8dp</dimen>
<dimen name="padding_medium">16dp</dimen>
<dimen name="padding_large">24dp</dimen>
<dimen name="text_size_small">12sp</dimen>
<dimen name="text_size_medium">16sp</dimen>
<dimen name="text_size_large">20sp</dimen>
</resources>
Styles and Themes
<!-- res/values/styles.xml -->
<resources>
<!-- Base application theme -->
<style name="AppTheme" parent="Theme.MaterialComponents.DayNight">
<item name="colorPrimary">@color/purple_500</item>
<item name="colorPrimaryVariant">@color/purple_700</item>
<item name="colorOnPrimary">@color/white</item>
<item name="colorSecondary">@color/teal_200</item>
</style>
<!-- Custom style -->
<style name="CustomButton" parent="Widget.MaterialComponents.Button">
<item name="android:textColor">@color/white</item>
<item name="backgroundTint">@color/purple_500</item>
<item name="cornerRadius">8dp</item>
</style>
</resources>
Debugging
Logcat
import android.util.Log
class MainActivity : AppCompatActivity() {
companion object {
private const val TAG = "MainActivity"
}
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
// Different log levels
Log.v(TAG, "Verbose message") // Verbose
Log.d(TAG, "Debug message") // Debug
Log.i(TAG, "Info message") // Info
Log.w(TAG, "Warning message") // Warning
Log.e(TAG, "Error message") // Error
// Log with exception
try {
// Code that might throw
} catch (e: Exception) {
Log.e(TAG, "Error occurred", e)
}
}
}
Breakpoints
- Click left margin in code editor to set breakpoint
- Run app in Debug mode (Shift + F9)
- Use Debug panel to step through code:
- Step Over (F8)
- Step Into (F7)
- Step Out (Shift + F8)
- Resume (F9)
Layout Inspector
Tools > Layout Inspector
- View hierarchy in real-time
- Inspect view properties
- Debug rendering issues
Build and Deploy
Building APK
# Debug build
./gradlew assembleDebug
# Release build
./gradlew assembleRelease
# Install debug APK
./gradlew installDebug
# APK location
# Debug: app/build/outputs/apk/debug/app-debug.apk
# Release: app/build/outputs/apk/release/app-release.apk
Signing Configuration
// app/build.gradle
android {
signingConfigs {
release {
storeFile file("release-keystore.jks")
storePassword "your_store_password"
keyAlias "your_key_alias"
keyPassword "your_key_password"
}
}
buildTypes {
release {
signingConfig signingConfigs.release
minifyEnabled true
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'),
'proguard-rules.pro'
}
}
}
ProGuard Rules
# app/proguard-rules.pro
# Keep model classes
-keep class com.example.app.models.** { *; }
# Keep Parcelable implementations
-keep class * implements android.os.Parcelable {
public static final ** CREATOR;
}
# Gson
-keepattributes Signature
-keepattributes *Annotation*
-keep class com.google.gson.** { *; }
# Retrofit
-keepattributes Signature
-keepattributes Exceptions
-keep class retrofit2.** { *; }
Best Practices
- Use ConstraintLayout for complex, flat hierarchies
- Implement View Binding instead of findViewById
- Follow Material Design guidelines
- Use string resources instead of hardcoded strings
- Handle configuration changes properly
- Use Fragments for reusable UI components
- Implement proper error handling and user feedback
- Test on multiple devices and screen sizes
- Optimize layouts to reduce overdraw
- Use Android Architecture Components (ViewModel, LiveData, Room)
Related Resources
Android Binder
Binder is Android's inter-process communication (IPC) mechanism. It's a custom implementation allowing processes to communicate efficiently and securely.
Overview
Binder enables:
- Cross-process method invocation
- Object reference passing
- Security via UID/PID checking
- Death notification
Architecture
Client Process Binder Driver Server Process
│ │ │
│──Service Request──────>│ │
│ │──Forward Request────>│
│ │<──Response───────────│
│<──Return Result────────│ │
AIDL (Android Interface Definition Language)
// ICalculator.aidl
package com.example;
interface ICalculator {
int add(int a, int b);
int subtract(int a, int b);
}
Service Implementation
// CalculatorService.java
public class CalculatorService extends Service {
private final ICalculator.Stub binder = new ICalculator.Stub() {
@Override
public int add(int a, int b) {
return a + b;
}
@Override
public int subtract(int a, int b) {
return a - b;
}
};
@Override
public IBinder onBind(Intent intent) {
return binder;
}
}
Client Usage
// Client code
ServiceConnection connection = new ServiceConnection() {
    @Override
    public void onServiceConnected(ComponentName name, IBinder service) {
        ICalculator calculator = ICalculator.Stub.asInterface(service);
        try {
            int result = calculator.add(5, 3); // Result: 8
        } catch (RemoteException e) {
            // AIDL calls can fail if the remote process dies
        }
    }
    @Override
    public void onServiceDisconnected(ComponentName name) {
        // Handle disconnection
    }
};
bindService(intent, connection, Context.BIND_AUTO_CREATE);
Key Features
- Security: Permission checking at IPC boundaries
- Reference Counting: Automatic resource management
- Death Recipients: Notification when remote process dies
- Asynchronous: Non-blocking calls with oneway keyword
Binder is fundamental to Android's architecture, enabling system services and app communication.
Android Debug Bridge (ADB)
Overview
Android Debug Bridge (ADB) is a versatile command-line tool that lets you communicate with an Android device. ADB facilitates a variety of device actions, such as installing and debugging apps, and it provides access to a Unix shell that you can use to run various commands on a device.
ADB is included in the Android SDK Platform Tools package and can be used with physical devices connected via USB or with emulators.
Installation
Linux/Mac
# Install via Android SDK Platform Tools
# Or use package manager
sudo apt install adb # Ubuntu/Debian
brew install android-platform-tools # macOS
# Verify installation
adb version
Windows
# Download Android SDK Platform Tools from:
# https://developer.android.com/studio/releases/platform-tools
# Add to PATH and verify
adb version
Setup and Connection
Enable Developer Options
- Go to Settings > About Phone
- Tap "Build Number" 7 times
- Go back to Settings > Developer Options
- Enable "USB Debugging"
Connect Device via USB
# List connected devices
adb devices
# Output example:
# List of devices attached
# 1234567890ABCDEF device
# emulator-5554 device
# Connect to specific device
adb -s 1234567890ABCDEF shell
Connect Device via WiFi
# Connect device via USB first, then:
# Get device IP address
adb shell ip addr show wlan0
# Enable TCP/IP mode on port 5555
adb tcpip 5555
# Disconnect USB and connect via WiFi
adb connect 192.168.1.100:5555
# Verify connection
adb devices
# Disconnect
adb disconnect 192.168.1.100:5555
# Return to USB mode
adb usb
Basic Commands
Device Management
# List all connected devices
adb devices -l
# Start ADB server
adb start-server
# Kill ADB server
adb kill-server
# Restart ADB server
adb kill-server && adb start-server
# Wait for device to be connected
adb wait-for-device
# Get device state
adb get-state
# Get device serial number
adb get-serialno
Device Information
# Get device model
adb shell getprop ro.product.model
# Get Android version
adb shell getprop ro.build.version.release
# Get device manufacturer
adb shell getprop ro.product.manufacturer
# Get device serial number
adb shell getprop ro.serialno
# Get device resolution
adb shell wm size
# Get device density
adb shell wm density
# Display all properties
adb shell getprop
# Get battery status
adb shell dumpsys battery
# Get CPU information
adb shell cat /proc/cpuinfo
# Get memory information
adb shell cat /proc/meminfo
App Management
Installing and Uninstalling Apps
# Install APK
adb install app.apk
# Reinstall existing app (keep data)
adb install -r app.apk
# Install APK to SD card
adb install -s app.apk
# Grant all runtime permissions at install
adb install -g app.apk
# Uninstall app
adb uninstall com.example.app
# Uninstall app but keep data
adb uninstall -k com.example.app
Package Information
# List all packages
adb shell pm list packages
# List third-party packages
adb shell pm list packages -3
# List system packages
adb shell pm list packages -s
# Search for specific package
adb shell pm list packages | grep keyword
# Get path of installed package
adb shell pm path com.example.app
# Get app information
adb shell dumpsys package com.example.app
# Clear app data
adb shell pm clear com.example.app
# Enable/Disable app
adb shell pm enable com.example.app
adb shell pm disable com.example.app
Running Apps
# Start an activity
adb shell am start -n com.example.app/.MainActivity
# Start activity with data
adb shell am start -a android.intent.action.VIEW -d "https://example.com"
# Start service
adb shell am startservice com.example.app/.MyService
# Broadcast intent
adb shell am broadcast -a android.intent.action.BOOT_COMPLETED
# Force stop app
adb shell am force-stop com.example.app
# Kill app process
adb shell am kill com.example.app
File Operations
Copying Files
# Copy file from device to computer
adb pull /sdcard/file.txt ~/Desktop/
# Copy file from computer to device
adb push ~/Desktop/file.txt /sdcard/
# Copy directory recursively
adb pull /sdcard/DCIM/ ~/Pictures/
# Copy with progress display
adb pull /sdcard/large_file.mp4 .
# Push multiple files
adb push file1.txt file2.txt /sdcard/
File System Navigation
# Access device shell
adb shell
# Navigate directories (once in shell)
cd /sdcard
ls -la
pwd
# Create directory
adb shell mkdir /sdcard/NewFolder
# Remove file
adb shell rm /sdcard/file.txt
# Remove directory
adb shell rm -r /sdcard/OldFolder
# Change file permissions
adb shell chmod 777 /sdcard/file.txt
# View file contents
adb shell cat /sdcard/file.txt
# Search for files
adb shell find /sdcard -name "*.txt"
Logging and Debugging
Logcat
# View all logs
adb logcat
# Clear log buffer
adb logcat -c
# View logs with specific priority
adb logcat *:E # Error
adb logcat *:W # Warning
adb logcat *:I # Info
adb logcat *:D # Debug
adb logcat *:V # Verbose
# Filter by tag
adb logcat -s MyApp
# Filter by multiple tags
adb logcat -s MyApp:D ActivityManager:W
# Save logs to file
adb logcat > logfile.txt
# View logs with timestamp
adb logcat -v time
# View logs in different formats
adb logcat -v brief
adb logcat -v process
adb logcat -v tag
adb logcat -v thread
adb logcat -v raw
adb logcat -v long
# Filter using grep
adb logcat | grep "keyword"
# View specific buffer
adb logcat -b radio # Radio/telephony logs
adb logcat -b events # Event logs
adb logcat -b main # Main application logs
adb logcat -b system # System logs
adb logcat -b crash # Crash logs
# Continuous monitoring with color
adb logcat -v color
Bug Reports
# Generate bug report
adb bugreport
# Save bug report to file
adb bugreport > bugreport.txt
# Generate zipped bug report (Android 7.0+)
adb bugreport bugreport.zip
Screen Control
Screenshots and Screen Recording
# Take screenshot
adb shell screencap /sdcard/screenshot.png
adb pull /sdcard/screenshot.png
# Take screenshot (one command)
adb exec-out screencap -p > screenshot.png
# Record screen (Ctrl+C to stop)
adb shell screenrecord /sdcard/demo.mp4
# Record with time limit (max 180 seconds)
adb shell screenrecord --time-limit 30 /sdcard/demo.mp4
# Record with specific size
adb shell screenrecord --size 1280x720 /sdcard/demo.mp4
# Record with specific bitrate
adb shell screenrecord --bit-rate 6000000 /sdcard/demo.mp4
# Pull recorded video
adb pull /sdcard/demo.mp4
Screen Input
# Tap at coordinates (x, y)
adb shell input tap 500 1000
# Swipe from (x1,y1) to (x2,y2) over duration ms
adb shell input swipe 500 1000 500 200 300
# Type text
adb shell input text "Hello%sWorld" # %s represents space
# Press key
adb shell input keyevent KEYCODE_HOME
adb shell input keyevent KEYCODE_BACK
adb shell input keyevent KEYCODE_MENU
adb shell input keyevent 3 # Home key (key code)
# Common key codes
# KEYCODE_HOME = 3
# KEYCODE_BACK = 4
# KEYCODE_MENU = 82
# KEYCODE_POWER = 26
# KEYCODE_VOLUME_UP = 24
# KEYCODE_VOLUME_DOWN = 25
System Control
Power Management
# Reboot device
adb reboot
# Reboot to recovery mode
adb reboot recovery
# Reboot to bootloader
adb reboot bootloader
# Shutdown device (requires root)
adb shell reboot -p
# Wake up screen
adb shell input keyevent KEYCODE_WAKEUP
# Sleep screen
adb shell input keyevent KEYCODE_SLEEP
Network
# Check WiFi status
adb shell dumpsys wifi
# Enable WiFi
adb shell svc wifi enable
# Disable WiFi
adb shell svc wifi disable
# Check network connectivity
adb shell ping -c 4 google.com
# Get IP address
adb shell ip addr show wlan0
# Enable/Disable data
adb shell svc data enable
adb shell svc data disable
Settings
# Get setting value
adb shell settings get system screen_brightness
# Set setting value
adb shell settings put system screen_brightness 100
# Common settings namespaces:
# - system: User preferences
# - secure: Secure system settings
# - global: Device-wide settings
# Enable airplane mode
adb shell settings put global airplane_mode_on 1
adb shell am broadcast -a android.intent.action.AIRPLANE_MODE
# Disable animations (for testing)
adb shell settings put global window_animation_scale 0
adb shell settings put global transition_animation_scale 0
adb shell settings put global animator_duration_scale 0
Advanced Commands
Dumpsys
# Get system information
adb shell dumpsys
# Battery information
adb shell dumpsys battery
# Memory usage
adb shell dumpsys meminfo
adb shell dumpsys meminfo com.example.app
# CPU usage
adb shell dumpsys cpuinfo
# Display information
adb shell dumpsys display
# Activity information
adb shell dumpsys activity
# Current activity
adb shell dumpsys activity activities | grep mResumedActivity
# Package information
adb shell dumpsys package com.example.app
# Window information
adb shell dumpsys window
Performance Monitoring
# Monitor CPU usage
adb shell top
# Monitor specific process
adb shell top | grep com.example.app
# Get process list
adb shell ps
# Get process info by name
adb shell ps | grep com.example.app
# Memory stats
adb shell procrank
# Disk usage
adb shell df
# Network statistics
adb shell netstat
Database Operations
# Access app database (requires root or debuggable app)
adb shell run-as com.example.app
# Navigate to database directory
cd /data/data/com.example.app/databases/
# Pull database
adb exec-out run-as com.example.app cat databases/mydb.db > mydb.db
# Query database using sqlite3
adb shell "run-as com.example.app sqlite3 databases/mydb.db 'SELECT * FROM users;'"
Testing and Automation
Monkey Testing
# Generate random events
adb shell monkey -p com.example.app 1000
# Monkey with specific event types
adb shell monkey -p com.example.app --pct-touch 70 --pct-motion 30 1000
# Monkey with seed (reproducible)
adb shell monkey -p com.example.app -s 100 1000
# Throttle events (delay in ms)
adb shell monkey -p com.example.app --throttle 500 1000
# Ignore crashes and continue
adb shell monkey -p com.example.app --ignore-crashes 1000
UI Automator
# Dump UI hierarchy
adb shell uiautomator dump
# Pull UI hierarchy XML
adb pull /sdcard/window_dump.xml
# Run UI Automator test
adb shell uiautomator runtest UiAutomatorTest.jar -c com.example.test.MyTest
Scripting with ADB
Batch Operations
#!/bin/bash
# Install app on all connected devices
for device in $(adb devices | grep -v "List" | awk '{print $1}'); do
echo "Installing on device: $device"
adb -s $device install app.apk
done
# Clear app data on all devices
for device in $(adb devices | grep -v "List" | awk '{print $1}'); do
echo "Clearing data on device: $device"
adb -s $device shell pm clear com.example.app
done
Automated Screenshot Script
#!/bin/bash
# Take screenshot and save with timestamp
timestamp=$(date +"%Y%m%d_%H%M%S")
filename="screenshot_${timestamp}.png"
adb exec-out screencap -p > "$filename"
echo "Screenshot saved: $filename"
Log Filtering Script
#!/bin/bash
# Monitor logs for specific package
package="com.example.app"
adb logcat | grep --line-buffered "$package" | while read line; do
echo "[$(date +"%H:%M:%S")] $line"
done
Troubleshooting
Common Issues
Device not detected:
# Check USB connection
lsusb # Linux
system_profiler SPUSBDataType # macOS
# Restart ADB
adb kill-server
adb start-server
# Check device authorization
# Accept the authorization prompt on device
Permission denied:
# Check USB debugging is enabled
# Revoke USB debugging authorizations and reconnect
# Settings > Developer Options > Revoke USB Debugging Authorizations
# Linux: Add udev rules
sudo vim /etc/udev/rules.d/51-android.rules
# Add: SUBSYSTEM=="usb", ATTR{idVendor}=="18d1", MODE="0666", GROUP="plugdev"
sudo udevadm control --reload-rules
Multiple devices:
# Specify device with -s flag
adb -s 1234567890ABCDEF shell
# Or use -d for physical device, -e for emulator
adb -d shell # Physical device
adb -e shell # Emulator
Best Practices
- Always specify device with `-s` when multiple devices are connected
- Use `adb wait-for-device` in scripts before commands
- Clear logcat before testing: `adb logcat -c`
- Use appropriate log levels to reduce noise
- Save important logs to files for later analysis
- Be careful with `rm` commands - there's no undo
- Test commands on emulator before using on physical device
- Keep ADB updated with latest platform tools
- Use `adb shell` for interactive sessions, direct commands for scripts
- Always pull important data before performing system changes
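Device targeting with `-s` is easy to get wrong in scripts. As an illustrative sketch (the `parse_adb_devices` name is made up, not part of the adb tooling), a small Python helper can extract the serials of authorized devices from `adb devices` output so each subsequent command can address one device explicitly:

```python
# Hypothetical helper: parse `adb devices` output so scripts can target
# each device explicitly with -s, per the practices above.
def parse_adb_devices(output):
    """Return serials of devices in the 'device' state."""
    serials = []
    # Skip the "List of devices attached" header line
    for line in output.strip().splitlines()[1:]:
        parts = line.split()
        if len(parts) >= 2 and parts[1] == "device":
            serials.append(parts[0])  # ignore 'unauthorized'/'offline' entries
    return serials

sample = """List of devices attached
1234567890ABCDEF\tdevice
emulator-5554\tdevice
0099AABB\tunauthorized
"""
print(parse_adb_devices(sample))  # ['1234567890ABCDEF', 'emulator-5554']
```

Feed the result into `subprocess` calls of the form `adb -s <serial> ...` to run the same command on every connected device.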
Security Considerations
- Disable USB debugging when not in development
- Be cautious when connecting to devices over WiFi
- Don't leave ADB over TCP/IP enabled on public networks
- Review USB debugging authorization requests carefully
- Use secure, trusted computers for ADB connections
- Never share bug reports publicly without reviewing contents first
References
Quick Reference Card
# Connection
adb devices # List devices
adb connect IP:5555 # Connect via WiFi
# Apps
adb install app.apk # Install app
adb uninstall package.name # Uninstall app
adb shell pm list packages # List packages
# Files
adb push local remote # Upload file
adb pull remote local # Download file
# Shell
adb shell # Interactive shell
adb shell command # Run single command
# Logs
adb logcat # View logs
adb logcat -c # Clear logs
# Screen
adb shell screencap /sdcard/s.png # Screenshot
adb shell screenrecord /sdcard/v.mp4 # Record screen
# System
adb reboot # Reboot device
adb shell dumpsys battery # Battery info
Data Structures
Overview
A data structure is a specialized format for organizing, processing, retrieving, and storing data. Different data structures are suited for different kinds of applications, and some are highly specialized for specific tasks. Understanding data structures is fundamental to writing efficient algorithms and building scalable software systems.
Why Data Structures Matter
- Efficiency: Right data structure can dramatically improve performance
- Organization: Logical way to organize and manage data
- Reusability: Common patterns for solving problems
- Abstraction: Hide implementation details
- Optimization: Trade-offs between time and space complexity
Classification of Data Structures
Linear Data Structures
Elements are arranged in sequential order:
- Arrays
- Linked Lists
- Stacks
- Queues
Non-Linear Data Structures
Elements are arranged hierarchically or in a network:
- Trees
- Graphs
- Tries
- Hash Tables
Static vs Dynamic
- Static: Fixed size (arrays)
- Dynamic: Size can change (linked lists, dynamic arrays)
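The static/dynamic split above can be bridged: a dynamic array is a static array that doubles its capacity when full, which is roughly how Python lists grow. A minimal sketch (class and method names are illustrative):

```python
# Dynamic array sketch: a fixed-size block that doubles when full,
# giving amortized O(1) appends at the cost of occasional O(n) copies.
class DynamicArray:
    def __init__(self):
        self.capacity = 1
        self.length = 0
        self.data = [None] * self.capacity

    def append(self, item):
        if self.length == self.capacity:
            self._resize(2 * self.capacity)  # double when full
        self.data[self.length] = item
        self.length += 1

    def _resize(self, new_capacity):
        new_data = [None] * new_capacity
        for i in range(self.length):  # copy existing elements - O(n)
            new_data[i] = self.data[i]
        self.data = new_data
        self.capacity = new_capacity

    def __getitem__(self, index):
        if not 0 <= index < self.length:
            raise IndexError("index out of range")
        return self.data[index]

    def __len__(self):
        return self.length

arr = DynamicArray()
for i in range(10):
    arr.append(i)
print(len(arr), arr.capacity)  # 10 16
```

Capacity grows 1 → 2 → 4 → 8 → 16, so ten appends trigger four copies; averaged over all appends, the copy cost is constant per element.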
Core Data Structures
1. Arrays
Contiguous memory locations storing elements of the same type.
# Array operations
arr = [1, 2, 3, 4, 5]
# Access - $O(1)$
element = arr[2] # 3
# Insert at end - $O(1)$ amortized
arr.append(6)
# Insert at position - $O(n)$
arr.insert(2, 10)
# Delete - $O(n)$
arr.remove(10)
# Search - $O(n)$
if 4 in arr:
print("Found")
# 2D Array
matrix = [
[1, 2, 3],
[4, 5, 6],
[7, 8, 9]
]
# Access element
value = matrix[1][2] # 6
Time Complexity:
- Access: $O(1)$
- Search: $O(n)$
- Insert: $O(n)$
- Delete: $O(n)$
Space Complexity: $O(n)$
See: Arrays
2. Linked Lists
Nodes connected via pointers, allowing efficient insertion/deletion.
class Node:
def __init__(self, data):
self.data = data
self.next = None
class LinkedList:
def __init__(self):
self.head = None
def insert_at_beginning(self, data):
new_node = Node(data)
new_node.next = self.head
self.head = new_node
def insert_at_end(self, data):
new_node = Node(data)
if not self.head:
self.head = new_node
return
current = self.head
while current.next:
current = current.next
current.next = new_node
def delete(self, key):
current = self.head
# Delete head
if current and current.data == key:
self.head = current.next
return
# Delete other node
prev = None
while current and current.data != key:
prev = current
current = current.next
if current:
prev.next = current.next
def search(self, key):
current = self.head
while current:
if current.data == key:
return True
current = current.next
return False
def display(self):
elements = []
current = self.head
while current:
elements.append(current.data)
current = current.next
return elements
Types:
- Singly Linked List
- Doubly Linked List
- Circular Linked List
Time Complexity:
- Access: $O(n)$
- Search: $O(n)$
- Insert at beginning: $O(1)$
- Insert at end: $O(n)$ or $O(1)$ with tail pointer
- Delete: $O(n)$
See: Linked Lists
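The singly linked list above only walks forward; a doubly linked list adds a `prev` pointer, making deletion of a known node $O(1)$ and reverse traversal trivial. A minimal sketch (names are illustrative):

```python
# Doubly linked list sketch: each node keeps next AND prev pointers.
class DNode:
    def __init__(self, data):
        self.data = data
        self.prev = None
        self.next = None

class DoublyLinkedList:
    def __init__(self):
        self.head = None
        self.tail = None

    def append(self, data):
        """Insert at end - O(1) thanks to the tail pointer."""
        node = DNode(data)
        if not self.head:
            self.head = self.tail = node
        else:
            node.prev = self.tail
            self.tail.next = node
            self.tail = node

    def delete_node(self, node):
        """Unlink a known node - O(1), no traversal needed."""
        if node.prev:
            node.prev.next = node.next
        else:
            self.head = node.next
        if node.next:
            node.next.prev = node.prev
        else:
            self.tail = node.prev

    def to_list(self):
        result, current = [], self.head
        while current:
            result.append(current.data)
            current = current.next
        return result

dll = DoublyLinkedList()
for x in [1, 2, 3]:
    dll.append(x)
dll.delete_node(dll.head.next)  # remove the middle node (2)
print(dll.to_list())  # [1, 3]
```

This is the structure behind use cases like browser history, where you move both backward and forward.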
3. Stacks
LIFO (Last In, First Out) structure.
class Stack:
def __init__(self):
self.items = []
def push(self, item):
"""Add item to top - $O(1)$"""
self.items.append(item)
def pop(self):
"""Remove and return top item - $O(1)$"""
if not self.is_empty():
return self.items.pop()
raise IndexError("Stack is empty")
def peek(self):
"""Return top item without removing - $O(1)$"""
if not self.is_empty():
return self.items[-1]
raise IndexError("Stack is empty")
def is_empty(self):
"""Check if stack is empty - $O(1)$"""
return len(self.items) == 0
def size(self):
"""Return number of items - $O(1)$"""
return len(self.items)
# Usage
stack = Stack()
stack.push(1)
stack.push(2)
stack.push(3)
print(stack.pop()) # 3
print(stack.peek()) # 2
# Applications
def is_balanced(expression):
"""Check if parentheses are balanced"""
stack = []
opening = "([{"
closing = ")]}"
pairs = {"(": ")", "[": "]", "{": "}"}
for char in expression:
if char in opening:
stack.append(char)
elif char in closing:
if not stack or pairs[stack.pop()] != char:
return False
return len(stack) == 0
# Reverse string using stack
def reverse_string(s):
    stack = list(s)
    result = []
    while stack:
        result.append(stack.pop())  # popping reverses the order
    return ''.join(result)
Applications:
- Function call stack
- Undo/Redo operations
- Expression evaluation
- Backtracking algorithms
- Browser history
See: Stacks
4. Queues
FIFO (First In, First Out) structure.
from collections import deque
class Queue:
def __init__(self):
self.items = deque()
def enqueue(self, item):
"""Add item to rear - $O(1)$"""
self.items.append(item)
def dequeue(self):
"""Remove and return front item - $O(1)$"""
if not self.is_empty():
return self.items.popleft()
raise IndexError("Queue is empty")
def front(self):
"""Return front item - $O(1)$"""
if not self.is_empty():
return self.items[0]
raise IndexError("Queue is empty")
def is_empty(self):
"""Check if queue is empty - $O(1)$"""
return len(self.items) == 0
def size(self):
"""Return number of items - $O(1)$"""
return len(self.items)
# Priority Queue
import heapq
class PriorityQueue:
def __init__(self):
self.heap = []
def push(self, item, priority):
"""Add item with priority - $O(\log n)$"""
heapq.heappush(self.heap, (priority, item))
def pop(self):
"""Remove and return highest priority item - $O(\log n)$"""
if self.heap:
return heapq.heappop(self.heap)[1]
raise IndexError("Queue is empty")
# Circular Queue
class CircularQueue:
def __init__(self, size):
self.size = size
self.queue = [None] * size
self.front = self.rear = -1
def enqueue(self, item):
if (self.rear + 1) % self.size == self.front:
raise Exception("Queue is full")
if self.front == -1:
self.front = 0
self.rear = (self.rear + 1) % self.size
self.queue[self.rear] = item
def dequeue(self):
if self.front == -1:
raise Exception("Queue is empty")
item = self.queue[self.front]
if self.front == self.rear:
self.front = self.rear = -1
else:
self.front = (self.front + 1) % self.size
return item
Types:
- Simple Queue
- Circular Queue
- Priority Queue
- Double-Ended Queue (Deque)
Applications:
- Task scheduling
- BFS traversal
- Print queue
- Buffer management
- Async processing
See: Queues
5. Hash Tables
Key-value pairs with $O(1)$ average-case operations.
class HashTable:
def __init__(self, size=10):
self.size = size
self.table = [[] for _ in range(size)]
def _hash(self, key):
"""Hash function - $O(1)$"""
return hash(key) % self.size
def insert(self, key, value):
"""Insert key-value pair - $O(1)$ average"""
index = self._hash(key)
# Update if key exists
for i, (k, v) in enumerate(self.table[index]):
if k == key:
self.table[index][i] = (key, value)
return
# Insert new key-value
self.table[index].append((key, value))
def get(self, key):
"""Get value by key - $O(1)$ average"""
index = self._hash(key)
for k, v in self.table[index]:
if k == key:
return v
raise KeyError(f"Key '{key}' not found")
def delete(self, key):
"""Delete key-value pair - $O(1)$ average"""
index = self._hash(key)
for i, (k, v) in enumerate(self.table[index]):
if k == key:
self.table[index].pop(i)
return
raise KeyError(f"Key '{key}' not found")
def contains(self, key):
"""Check if key exists - $O(1)$ average"""
try:
self.get(key)
return True
except KeyError:
return False
# Python dict is a hash table
hash_map = {}
hash_map["name"] = "John"
hash_map["age"] = 30
# Counter using hash table
from collections import Counter
text = "hello world"
char_count = Counter(text)
Collision Resolution:
- Chaining (linked lists)
- Open addressing (linear probing, quadratic probing, double hashing)
Time Complexity:
- Average: $O(1)$ for insert, delete, search
- Worst: $O(n)$ with many collisions
See: Hash Tables
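The `HashTable` above resolves collisions by chaining; the other strategy listed, open addressing, keeps everything in the table itself and probes for the next free slot. A minimal linear-probing sketch (names are illustrative; resizing and deletion tombstones are omitted for brevity):

```python
# Open addressing with linear probing: on collision, step to the next slot.
class LinearProbingTable:
    def __init__(self, size=8):
        self.size = size
        self.keys = [None] * size
        self.values = [None] * size

    def _probe(self, key):
        """Yield slot indices starting at hash(key), one step at a time."""
        index = hash(key) % self.size
        for _ in range(self.size):
            yield index
            index = (index + 1) % self.size  # linear step on collision

    def insert(self, key, value):
        for index in self._probe(key):
            if self.keys[index] is None or self.keys[index] == key:
                self.keys[index] = key
                self.values[index] = value
                return
        raise RuntimeError("table is full")  # a real table would resize

    def get(self, key):
        for index in self._probe(key):
            if self.keys[index] is None:
                break  # empty slot: key was never inserted
            if self.keys[index] == key:
                return self.values[index]
        raise KeyError(key)

t = LinearProbingTable(size=4)
t.insert("a", 1)
t.insert("b", 2)
print(t.get("a"), t.get("b"))  # 1 2
```

Quadratic probing and double hashing change only the step rule inside `_probe`; the trade-off versus chaining is better cache locality at the cost of clustering as the table fills.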
Advanced Data Structures
6. Trees
Hierarchical structure with nodes connected by edges.
class TreeNode:
def __init__(self, value):
self.value = value
self.left = None
self.right = None
class BinarySearchTree:
def __init__(self):
self.root = None
def insert(self, value):
"""Insert value - $O(\log n)$ average, $O(n)$ worst"""
if not self.root:
self.root = TreeNode(value)
else:
self._insert_recursive(self.root, value)
def _insert_recursive(self, node, value):
if value < node.value:
if node.left is None:
node.left = TreeNode(value)
else:
self._insert_recursive(node.left, value)
else:
if node.right is None:
node.right = TreeNode(value)
else:
self._insert_recursive(node.right, value)
def search(self, value):
"""Search for value - $O(\log n)$ average"""
return self._search_recursive(self.root, value)
def _search_recursive(self, node, value):
if node is None or node.value == value:
return node
if value < node.value:
return self._search_recursive(node.left, value)
return self._search_recursive(node.right, value)
def inorder_traversal(self, node, result=None):
"""Inorder: Left -> Root -> Right"""
if result is None:
result = []
if node:
self.inorder_traversal(node.left, result)
result.append(node.value)
self.inorder_traversal(node.right, result)
return result
Types:
- Binary Tree
- Binary Search Tree
- AVL Tree (self-balancing)
- Red-Black Tree
- B-Tree
- Heap
See: Trees documentation in algorithms
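Of the types listed, the heap deserves a quick demo: a binary min-heap gives $O(\log n)$ insert and pop-min, and Python ships one in the standard `heapq` module (it operates on a plain list):

```python
# Min-heap via the stdlib heapq module.
import heapq

nums = [5, 1, 8, 3, 2]
heapq.heapify(nums)             # rearrange in place into heap order - O(n)
heapq.heappush(nums, 0)         # insert - O(log n)
smallest = heapq.heappop(nums)  # remove minimum - O(log n)
print(smallest)  # 0

# Popping repeatedly yields elements in sorted order (the heapsort idea)
sorted_out = [heapq.heappop(nums) for _ in range(len(nums))]
print(sorted_out)  # [1, 2, 3, 5, 8]
```

This is also the backing structure for the `PriorityQueue` shown earlier in the Queues section.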
7. Graphs
Network of nodes (vertices) connected by edges.
# Adjacency List representation
class Graph:
def __init__(self):
self.graph = {}
def add_vertex(self, vertex):
"""Add vertex - $O(1)$"""
if vertex not in self.graph:
self.graph[vertex] = []
def add_edge(self, v1, v2):
"""Add edge - $O(1)$"""
if v1 in self.graph and v2 in self.graph:
self.graph[v1].append(v2)
self.graph[v2].append(v1) # For undirected graph
    def bfs(self, start):
        """Breadth-First Search - $O(V + E)$"""
        from collections import deque
        visited = set()
        queue = deque([start])
        result = []
        while queue:
            vertex = queue.popleft()  # $O(1)$; list.pop(0) would be $O(n)$
            if vertex not in visited:
                visited.add(vertex)
                result.append(vertex)
                queue.extend(self.graph[vertex])
        return result
def dfs(self, start, visited=None):
"""Depth-First Search - $O(V + E)$"""
if visited is None:
visited = set()
visited.add(start)
result = [start]
for neighbor in self.graph[start]:
if neighbor not in visited:
result.extend(self.dfs(neighbor, visited))
return result
# Adjacency Matrix representation
class GraphMatrix:
def __init__(self, num_vertices):
self.num_vertices = num_vertices
self.matrix = [[0] * num_vertices for _ in range(num_vertices)]
def add_edge(self, v1, v2, weight=1):
"""Add edge with optional weight"""
self.matrix[v1][v2] = weight
self.matrix[v2][v1] = weight # For undirected graph
Types:
- Directed/Undirected
- Weighted/Unweighted
- Cyclic/Acyclic
- Connected/Disconnected
See: Graph algorithms
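To see how the weighted variety extends the adjacency list above, here is a minimal Dijkstra shortest-path sketch (the `dijkstra` function and the sample graph are illustrative; it assumes non-negative edge weights):

```python
# Dijkstra's algorithm on a weighted directed graph stored as an
# adjacency list of (neighbor, weight) pairs.
import heapq

def dijkstra(graph, start):
    """Return a dict of shortest distances from start - O((V + E) log V)."""
    dist = {start: 0}
    heap = [(0, start)]
    while heap:
        d, u = heapq.heappop(heap)
        if d > dist.get(u, float("inf")):
            continue  # stale heap entry, a shorter path was already found
        for v, w in graph.get(u, []):
            nd = d + w
            if nd < dist.get(v, float("inf")):
                dist[v] = nd
                heapq.heappush(heap, (nd, v))
    return dist

graph = {
    "A": [("B", 1), ("C", 4)],
    "B": [("C", 2), ("D", 5)],
    "C": [("D", 1)],
    "D": [],
}
print(dijkstra(graph, "A"))  # {'A': 0, 'B': 1, 'C': 3, 'D': 4}
```

Note how the priority queue replaces the plain FIFO queue of BFS; with all weights equal to 1 the two algorithms visit vertices in the same order.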
Choosing the Right Data Structure
Array vs Linked List
Use Array when:
- Need random access
- Size is known and fixed
- Memory is contiguous
- Cache performance matters
Use Linked List when:
- Frequent insertions/deletions
- Size is unknown
- Don't need random access
- Memory fragmentation is acceptable
Stack vs Queue
Use Stack for:
- LIFO operations
- Recursion simulation
- Undo/redo functionality
- Expression evaluation
Use Queue for:
- FIFO operations
- Scheduling
- BFS traversal
- Resource sharing
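The LIFO/FIFO contrast above can be bridged: two stacks simulate a queue with amortized $O(1)$ operations, a classic interview exercise (class name is illustrative):

```python
# Queue built from two stacks: moving items between stacks reverses order.
class QueueWithTwoStacks:
    def __init__(self):
        self.inbox = []   # push new items here
        self.outbox = []  # pop items from here

    def enqueue(self, item):
        self.inbox.append(item)

    def dequeue(self):
        if not self.outbox:
            # Reverse order once: drain inbox into outbox
            while self.inbox:
                self.outbox.append(self.inbox.pop())
        if not self.outbox:
            raise IndexError("Queue is empty")
        return self.outbox.pop()

q = QueueWithTwoStacks()
q.enqueue(1)
q.enqueue(2)
print(q.dequeue())  # 1
q.enqueue(3)
print(q.dequeue())  # 2
print(q.dequeue())  # 3
```

Each element is pushed and popped at most twice in total, which is why the amortized cost per operation is constant even though a single `dequeue` can be $O(n)$.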
Hash Table vs Tree
Use Hash Table when:
- Need $O(1)$ lookup
- Order doesn't matter
- No range queries needed
- Keys are hashable
Use Tree when:
- Need sorted order
- Range queries required
- Prefix searches (Trie)
- Hierarchical data
Performance Comparison
| Operation | Array | Linked List | Stack | Queue | Hash Table | BST |
|---|---|---|---|---|---|---|
| Access | $O(1)$ | $O(n)$ | $O(n)$ | $O(n)$ | - | $O(\log n)$ |
| Search | $O(n)$ | $O(n)$ | $O(n)$ | $O(n)$ | $O(1)$* | $O(\log n)$ |
| Insert | $O(n)$ | $O(1)$** | $O(1)$ | $O(1)$ | $O(1)$* | $O(\log n)$ |
| Delete | $O(n)$ | $O(1)$** | $O(1)$ | $O(1)$ | $O(1)$* | $O(\log n)$ |
* Average case, ** At beginning/with reference
Common Operations
Traversal Patterns
# Array traversal
for i in range(len(arr)):
process(arr[i])
# Linked list traversal
current = head
while current:
process(current.data)
current = current.next
# Tree traversal (recursion)
def traverse_tree(node):
if node:
traverse_tree(node.left)
process(node.value)
traverse_tree(node.right)
# Graph traversal (BFS)
from collections import deque

def bfs(graph, start):
    visited = set()
    queue = deque([start])
    while queue:
        vertex = queue.popleft()
        if vertex not in visited:
            visited.add(vertex)
            process(vertex)
            queue.extend(graph[vertex])
Searching Patterns
# Linear search - $O(n)$
def linear_search(arr, target):
for i, val in enumerate(arr):
if val == target:
return i
return -1
# Binary search - $O(\log n)$
def binary_search(arr, target):
left, right = 0, len(arr) - 1
while left <= right:
mid = (left + right) // 2
if arr[mid] == target:
return mid
elif arr[mid] < target:
left = mid + 1
else:
right = mid - 1
return -1
# Hash table search - $O(1)$
def hash_search(hash_table, key):
return hash_table.get(key)
Real-World Applications
Arrays
- Database tables
- Image processing (pixel arrays)
- Dynamic programming tables
- Buffer implementation
Linked Lists
- Music playlists
- Browser history (doubly linked)
- Undo functionality
- Memory management (free lists)
Stacks
- Function call management
- Expression evaluation
- Backtracking (maze solving)
- Browser back button
Queues
- Print spooling
- CPU scheduling
- BFS traversal
- Message queues (async)
Hash Tables
- Database indexing
- Caching
- Symbol tables (compilers)
- Spell checkers
Trees
- File systems
- DOM (HTML)
- Decision trees (AI)
- Database indexing (B-trees)
Graphs
- Social networks
- Maps and navigation
- Network routing
- Recommendation systems
Interview Preparation
Essential Topics
- Arrays and Strings
  - Two pointers
  - Sliding window
  - Prefix sums
- Linked Lists
  - Reverse list
  - Detect cycle
  - Merge lists
- Stacks and Queues
  - Valid parentheses
  - Min/max stack
  - Implement queue with stacks
- Trees
  - Traversals
  - Height/depth
  - Lowest common ancestor
- Graphs
  - BFS/DFS
  - Cycle detection
  - Shortest path
- Hash Tables
  - Two sum
  - Group anagrams
  - LRU cache
Common Patterns
- Two Pointers: Array problems
- Fast/Slow Pointers: Linked list cycles
- Sliding Window: Subarray problems
- BFS/DFS: Tree/graph traversal
- Backtracking: Combinatorial problems
- Dynamic Programming: Optimization problems
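Of the patterns above, prefix sums is the only one not implemented elsewhere in these notes, so here is a minimal sketch. The idea: precompute running totals once so any range sum becomes a single subtraction.

```python
def build_prefix_sums(arr):
    """prefix[i] holds the sum of arr[0:i], so prefix has len(arr) + 1 entries."""
    prefix = [0] * (len(arr) + 1)
    for i, x in enumerate(arr):
        prefix[i + 1] = prefix[i] + x
    return prefix

def range_sum(prefix, left, right):
    """Sum of arr[left:right] in O(1) after O(n) preprocessing."""
    return prefix[right] - prefix[left]
```

For example, with `arr = [1, 2, 3, 4]` the prefix array is `[0, 1, 3, 6, 10]`, and the sum of `arr[1:3]` is `prefix[3] - prefix[1] = 5`.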
Available Resources
Explore detailed guides for specific data structures:
- Arrays - Array operations and techniques
- Linked Lists - Singly, doubly, circular lists
- Stacks - Stack implementation and applications
- Queues - Queue types and use cases
- Hash Tables - Hashing and collision resolution
- Trees - Binary trees, BST, AVL, traversals
- Graphs - Graph representations, traversal, algorithms
- Heaps - Min heaps, max heaps, priority queues
- Tries - Prefix trees, autocomplete, string matching
Best Practices
- Choose appropriately: Match data structure to problem
- Consider trade-offs: Time vs space complexity
- Test edge cases: Empty, single element, duplicates
- Optimize: Start simple, then optimize
- Document: Comment complex logic
- Practice: Regular coding practice
- Learn patterns: Recognize common patterns
- Understand internals: Know how they work
Next Steps
- Master the fundamental structures (array, linked list, stack, queue)
- Practice implementing each structure from scratch
- Solve problems using each data structure
- Learn when to use each structure
- Study advanced structures (trees, graphs, tries)
- Practice on coding platforms (LeetCode, HackerRank)
- Review time/space complexity for all operations
- Work on real-world projects using these structures
Remember: Understanding data structures is essential for writing efficient code and succeeding in technical interviews. Focus on understanding the concepts, not just memorizing implementations.
Arrays
Overview
An array is a fundamental data structure that stores elements of the same type in contiguous memory locations. Arrays provide fast, constant-time access to elements using an index, making them one of the most commonly used data structures in programming.
Key Concepts
Characteristics
- Fixed Size: Most arrays have a fixed size determined at creation
- Contiguous Memory: Elements stored sequentially in memory
- Index-Based: Access elements using zero-based indexing
- Homogeneous: All elements must be of the same type
- Fast Access: $O(1)$ time complexity for accessing any element
Memory Layout
Index: 0 1 2 3 4
Array: | 10 | 20 | 30 | 40 | 50 |
Address: 1000 1004 1008 1012 1016 (for 4-byte integers)
Time Complexity
| Operation | Time Complexity |
|---|---|
| Access | $O(1)$ |
| Search | $O(n)$ |
| Insert (at end) | $O(1)$ amortized* |
| Insert (at position) | $O(n)$ |
| Delete (at end) | $O(1)$ |
| Delete (at position) | $O(n)$ |
*For dynamic arrays like Python lists or C++ vectors
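The amortized $O(1)$ append can be observed directly: CPython over-allocates a list's backing buffer, so its byte size grows in occasional jumps rather than on every append. The exact growth points are a CPython implementation detail, so treat this as an illustration, not a guarantee.

```python
import sys

# Record the list's allocated size (in bytes) after each append.
sizes = []
lst = []
for i in range(32):
    lst.append(i)
    sizes.append(sys.getsizeof(lst))

# The size stays flat between reallocations, then jumps: most appends
# cost O(1), and the occasional O(n) copy amortizes out.
print(sorted(set(sizes)))
```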
Code Examples
Python
# Creating arrays
arr = [1, 2, 3, 4, 5]
arr_zeros = [0] * 10 # [0, 0, 0, ..., 0] (10 elements)
# Accessing elements
first = arr[0] # 1
last = arr[-1] # 5 (negative indexing)
# Modifying elements
arr[2] = 100 # [1, 2, 100, 4, 5]
# Slicing
sub = arr[1:4] # [2, 100, 4]
reversed_arr = arr[::-1] # [5, 4, 100, 2, 1]
# Common operations
arr.append(6) # Add to end: [1, 2, 100, 4, 5, 6]
arr.insert(2, 99) # Insert at index: [1, 2, 99, 100, 4, 5, 6]
arr.pop() # Remove last: returns 6
arr.remove(99) # Remove first occurrence of 99
length = len(arr) # Get length
# Iteration
for element in arr:
print(element)
for index, element in enumerate(arr):
print(f"Index {index}: {element}")
# List comprehension
squares = [x**2 for x in range(10)] # [0, 1, 4, 9, 16, ..., 81]
evens = [x for x in arr if x % 2 == 0] # Filter even numbers
# 2D Arrays (Matrix)
matrix = [
[1, 2, 3],
[4, 5, 6],
[7, 8, 9]
]
element = matrix[1][2] # Access: 6
JavaScript
// Creating arrays
let arr = [1, 2, 3, 4, 5];
let arr2 = new Array(10); // Array with 10 undefined elements
let arr3 = Array.from({length: 5}, (_, i) => i); // [0, 1, 2, 3, 4]
// Accessing and modifying
arr[0] = 100;
let last = arr[arr.length - 1]; // 5
// Common methods
arr.push(6); // Add to end
arr.pop(); // Remove from end
arr.unshift(0); // Add to beginning
arr.shift(); // Remove from beginning
arr.splice(2, 1, 99); // Remove 1 element at index 2, insert 99
// Iteration
arr.forEach((element, index) => {
console.log(index, element);
});
// Map, filter, reduce
let doubled = arr.map(x => x * 2);
let evens = arr.filter(x => x % 2 === 0);
let sum = arr.reduce((acc, x) => acc + x, 0);
// Find elements
let found = arr.find(x => x > 3); // First element > 3
let index = arr.findIndex(x => x > 3); // Index of first element > 3
let includes = arr.includes(3); // true if 3 exists
// Sorting
arr.sort((a, b) => a - b); // Ascending
arr.sort((a, b) => b - a); // Descending
// Spread operator
let combined = [...arr, ...arr2];
let copy = [...arr];
C++
#include <iostream>
#include <vector>
#include <array>
using namespace std;
int main() {
// Static array
int arr[5] = {1, 2, 3, 4, 5};
int size = sizeof(arr) / sizeof(arr[0]); // 5
// Access and modify
arr[0] = 100;
int last = arr[size - 1];
// std::array (fixed size, safer)
array<int, 5> std_arr = {1, 2, 3, 4, 5};
std_arr[0] = 100;
int sz = std_arr.size();
// std::vector (dynamic array)
vector<int> vec = {1, 2, 3, 4, 5};
vec.push_back(6); // Add to end
vec.pop_back(); // Remove from end
vec.insert(vec.begin() + 2, 99); // Insert at index 2
vec.erase(vec.begin() + 2); // Remove at index 2
// Iteration
for (int i = 0; i < vec.size(); i++) {
cout << vec[i] << " ";
}
// Range-based for loop
for (int x : vec) {
cout << x << " ";
}
// 2D vector
vector<vector<int>> matrix(3, vector<int>(4, 0)); // 3x4 matrix of zeros
matrix[1][2] = 99;
return 0;
}
Java
import java.util.ArrayList;
import java.util.Arrays;
public class ArrayExamples {
public static void main(String[] args) {
// Static array
int[] arr = {1, 2, 3, 4, 5};
int[] arr2 = new int[10]; // 10 elements, initialized to 0
// Access and modify
arr[0] = 100;
int length = arr.length;
// ArrayList (dynamic array)
ArrayList<Integer> list = new ArrayList<>();
list.add(1);
list.add(2);
list.add(3);
list.add(2, 99); // Insert at index 2
list.remove(2); // Remove at index 2
int element = list.get(1); // Access index 1
list.set(1, 100); // Modify index 1
// Iteration
for (int i = 0; i < list.size(); i++) {
System.out.println(list.get(i));
}
for (int x : list) {
System.out.println(x);
}
// Useful methods
boolean contains = list.contains(3);
int idx = list.indexOf(3);
list.sort((a, b) -> a - b); // Sort
// Arrays utility
int[] arr3 = {3, 1, 4, 1, 5};
Arrays.sort(arr3);
int index = Arrays.binarySearch(arr3, 4); // Binary search (sorted array)
}
}
Common Algorithms
Linear Search
def linear_search(arr, target):
for i in range(len(arr)):
if arr[i] == target:
return i
return -1 # Not found
# Time: $O(n)$, Space: $O(1)$
Binary Search (Sorted Array)
def binary_search(arr, target):
left, right = 0, len(arr) - 1
while left <= right:
mid = left + (right - left) // 2 # Avoid overflow
if arr[mid] == target:
return mid
elif arr[mid] < target:
left = mid + 1
else:
right = mid - 1
return -1 # Not found
# Time: $O(\log n)$, Space: $O(1)$
Two Pointers Technique
def two_sum_sorted(arr, target):
"""Find two numbers that sum to target in sorted array"""
left, right = 0, len(arr) - 1
while left < right:
current_sum = arr[left] + arr[right]
if current_sum == target:
return [left, right]
elif current_sum < target:
left += 1
else:
right -= 1
return [] # Not found
Sliding Window
def max_sum_subarray(arr, k):
"""Maximum sum of k consecutive elements"""
if len(arr) < k:
return None
# Compute sum of first window
window_sum = sum(arr[:k])
max_sum = window_sum
# Slide the window
for i in range(k, len(arr)):
window_sum = window_sum - arr[i - k] + arr[i]
max_sum = max(max_sum, window_sum)
return max_sum
# Time: $O(n)$, Space: $O(1)$
Kadane's Algorithm (Maximum Subarray)
def max_subarray_sum(arr):
"""Find maximum sum of any contiguous subarray"""
max_so_far = arr[0]
max_ending_here = arr[0]
for i in range(1, len(arr)):
max_ending_here = max(arr[i], max_ending_here + arr[i])
max_so_far = max(max_so_far, max_ending_here)
return max_so_far
# Time: $O(n)$, Space: $O(1)$
# Example: [-2, 1, -3, 4, -1, 2, 1, -5, 4] -> 6 (subarray [4, -1, 2, 1])
Common Problems
Reverse an Array
def reverse(arr):
left, right = 0, len(arr) - 1
while left < right:
arr[left], arr[right] = arr[right], arr[left]
left += 1
right -= 1
return arr
Rotate Array
def rotate_right(arr, k):
"""Rotate array to the right by k positions"""
n = len(arr)
k = k % n # Handle k > n
# Reverse entire array
arr.reverse()
# Reverse first k elements
arr[:k] = reversed(arr[:k])
# Reverse remaining elements
arr[k:] = reversed(arr[k:])
return arr
# Example: [1, 2, 3, 4, 5], k=2 -> [4, 5, 1, 2, 3]
Remove Duplicates (Sorted Array)
def remove_duplicates(arr):
"""Remove duplicates in-place, return new length"""
if not arr:
return 0
write_index = 1
for i in range(1, len(arr)):
if arr[i] != arr[i - 1]:
arr[write_index] = arr[i]
write_index += 1
return write_index
Best Practices
1. Bounds Checking
# Bad
if arr[i] > 0: # May cause IndexError
# Good
if i < len(arr) and arr[i] > 0:
2. Avoid Modifying While Iterating
# Bad
for x in arr:
if x % 2 == 0:
arr.remove(x) # Causes skipping
# Good
arr = [x for x in arr if x % 2 != 0] # Create new list
3. Use Appropriate Data Structure
- Need frequent insertions/deletions? Consider linked list
- Need fast lookups? Consider hash table
- Working with numerical data? Use NumPy arrays
ELI10
Think of an array like a row of mailboxes in an apartment building:
- Each mailbox has a number (index): 0, 1, 2, 3, ...
- Each mailbox can hold one item (element)
- To get mail from mailbox #3, you go directly to it - very fast!
- All mailboxes are right next to each other in a line
- You know exactly how many mailboxes there are
The cool part: You can instantly go to any mailbox by its number, no need to check all the other mailboxes first!
The tricky part: If you want to add a new mailbox in the middle, you have to shift all the mailboxes after it to make room - that takes time!
Further Resources
- Arrays in Python - Official Docs
- MDN JavaScript Arrays
- C++ Vector Documentation
- LeetCode Array Problems
Linked Lists
Overview
A linked list is a linear data structure where elements (nodes) are connected via pointers/references rather than stored in contiguous memory. Each node contains data and a reference to the next node, creating a chain-like structure.
Key Concepts
Structure
Head
|
Data -> Next -> Data -> Next -> Data -> Next -> None
Types
| Type | Description | Use Case |
|---|---|---|
| Singly Linked List | Each node points to next node only | Standard, memory efficient |
| Doubly Linked List | Each node points to next and previous | Need bidirectional traversal |
| Circular Linked List | Last node points back to first | Round-robin scheduling |
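The circular variant in the table has no implementation elsewhere in these notes; here is a minimal sketch. Keeping only a tail pointer is enough, since the head is always `tail.next` — handy for round-robin scheduling where you cycle through items endlessly.

```python
class CNode:
    def __init__(self, data):
        self.data = data
        self.next = None

class CircularLinkedList:
    """Circular singly linked list: the tail's next pointer wraps to the head."""
    def __init__(self):
        self.tail = None  # head is always self.tail.next

    def append(self, data):
        node = CNode(data)
        if not self.tail:
            node.next = node            # a single node points to itself
            self.tail = node
        else:
            node.next = self.tail.next  # new node points to the head
            self.tail.next = node
            self.tail = node            # new node becomes the tail

    def to_list(self):
        """Collect values from one full trip around the circle."""
        if not self.tail:
            return []
        result = []
        current = self.tail.next  # start at the head
        while True:
            result.append(current.data)
            current = current.next
            if current is self.tail.next:  # back at the head: stop
                break
        return result
```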
Advantages vs Arrays
| Feature | Linked List | Array |
|---|---|---|
| Access | $O(n)$ | $O(1)$ |
| Insert/Delete at start | $O(1)$ | $O(n)$ |
| Insert/Delete in middle | $O(n)$ to find, $O(1)$ to insert | $O(n)$ |
| Memory | Flexible, dynamic | Fixed or expensive to resize |
| Cache Efficiency | Poor | Excellent |
Implementation
Python - Singly Linked List
class Node:
def __init__(self, data):
self.data = data
self.next = None
class LinkedList:
def __init__(self):
self.head = None
def append(self, data):
"""Add element to end"""
new_node = Node(data)
if not self.head:
self.head = new_node
return
current = self.head
while current.next:
current = current.next
current.next = new_node
def prepend(self, data):
"""Add element to beginning"""
new_node = Node(data)
new_node.next = self.head
self.head = new_node
def insert_after(self, prev_data, data):
"""Insert after specific value"""
current = self.head
while current and current.data != prev_data:
current = current.next
if current:
new_node = Node(data)
new_node.next = current.next
current.next = new_node
def delete(self, data):
"""Remove first occurrence"""
if not self.head:
return
# If head needs to be deleted
if self.head.data == data:
self.head = self.head.next
return
current = self.head
while current.next and current.next.data != data:
current = current.next
if current.next:
current.next = current.next.next
def search(self, data):
"""Find element"""
current = self.head
while current:
if current.data == data:
return True
current = current.next
return False
def display(self):
"""Print all elements"""
elements = []
current = self.head
while current:
elements.append(str(current.data))
current = current.next
print(" -> ".join(elements) + " -> None")
def __len__(self):
"""Get length"""
count = 0
current = self.head
while current:
count += 1
current = current.next
return count
# Usage
ll = LinkedList()
ll.append(1)
ll.append(2)
ll.append(3)
ll.prepend(0)
ll.display() # 0 -> 1 -> 2 -> 3 -> None
ll.delete(2)
ll.display() # 0 -> 1 -> 3 -> None
Python - Doubly Linked List
class DNode:
def __init__(self, data):
self.data = data
self.next = None
self.prev = None
class DoublyLinkedList:
def __init__(self):
self.head = None
def append(self, data):
"""Add to end"""
new_node = DNode(data)
if not self.head:
self.head = new_node
return
current = self.head
while current.next:
current = current.next
current.next = new_node
new_node.prev = current
def reverse_display(self):
"""Print in reverse"""
if not self.head:
return
current = self.head
while current.next:
current = current.next
elements = []
while current:
elements.append(str(current.data))
current = current.prev
print(" -> ".join(elements) + " -> None")
JavaScript
class Node {
constructor(data) {
this.data = data;
this.next = null;
}
}
class LinkedList {
constructor() {
this.head = null;
}
append(data) {
const newNode = new Node(data);
if (!this.head) {
this.head = newNode;
return;
}
let current = this.head;
while (current.next) {
current = current.next;
}
current.next = newNode;
}
prepend(data) {
const newNode = new Node(data);
newNode.next = this.head;
this.head = newNode;
}
delete(data) {
if (!this.head) return;
if (this.head.data === data) {
this.head = this.head.next;
return;
}
let current = this.head;
while (current.next && current.next.data !== data) {
current = current.next;
}
if (current.next) {
current.next = current.next.next;
}
}
display() {
let current = this.head;
let result = [];
while (current) {
result.push(current.data);
current = current.next;
}
console.log(result.join(" -> ") + " -> null");
}
}
// Usage
const ll = new LinkedList();
ll.append(1);
ll.append(2);
ll.prepend(0);
ll.display(); // 0 -> 1 -> 2 -> null
C++
#include <iostream>
using namespace std;
struct Node {
int data;
Node* next;
Node(int data) : data(data), next(nullptr) {}
};
class LinkedList {
private:
Node* head;
public:
LinkedList() : head(nullptr) {}
void append(int data) {
Node* newNode = new Node(data);
if (!head) {
head = newNode;
return;
}
Node* current = head;
while (current->next) {
current = current->next;
}
current->next = newNode;
}
void prepend(int data) {
Node* newNode = new Node(data);
newNode->next = head;
head = newNode;
}
void deleteNode(int data) {
if (!head) return;
if (head->data == data) {
Node* temp = head;
head = head->next;
delete temp;
return;
}
Node* current = head;
while (current->next && current->next->data != data) {
current = current->next;
}
if (current->next) {
Node* temp = current->next;
current->next = current->next->next;
delete temp;
}
}
void display() {
Node* current = head;
while (current) {
cout << current->data << " -> ";
current = current->next;
}
cout << "null\n";
}
~LinkedList() {
Node* current = head;
while (current) {
Node* temp = current;
current = current->next;
delete temp;
}
}
};
Common Problems
Reverse a Linked List
def reverse(head):
"""Reverse entire linked list"""
prev = None
current = head
while current:
next_temp = current.next # Save next
current.next = prev # Reverse link
prev = current # Move prev forward
current = next_temp # Move current forward
return prev # New head
Find Middle
def find_middle(head):
"""Find middle node using slow/fast pointers"""
slow = fast = head
while fast and fast.next:
slow = slow.next
fast = fast.next.next
return slow # Slow pointer at middle
Detect Cycle
def has_cycle(head):
"""Detect if linked list has cycle"""
slow = fast = head
while fast and fast.next:
slow = slow.next
fast = fast.next.next
if slow == fast: # Cycle detected
return True
return False
Merge Two Sorted Lists
def merge_sorted(l1, l2):
"""Merge two sorted linked lists"""
dummy = Node(0)
current = dummy
while l1 and l2:
if l1.data < l2.data:
current.next = l1
l1 = l1.next
else:
current.next = l2
l2 = l2.next
current = current.next
# Attach remaining
current.next = l1 if l1 else l2
return dummy.next
Remove Nth Node from End
def remove_nth_from_end(head, n):
"""Remove nth node from end"""
dummy = Node(0)
dummy.next = head
first = second = dummy
# Move first pointer n+1 steps ahead
for i in range(n + 1):
if not first:
return head
first = first.next
# Move both until first reaches end
while first:
first = first.next
second = second.next
# Remove node
second.next = second.next.next
return dummy.next
Time Complexity Summary
| Operation | Singly | Doubly |
|---|---|---|
| Access | $O(n)$ | $O(n)$ |
| Search | $O(n)$ | $O(n)$ |
| Insert at head | $O(1)$ | $O(1)$ |
| Insert at tail | $O(n)$ | $O(1)$* |
| Delete from head | $O(1)$ | $O(1)$ |
| Delete from tail | $O(n)$ | $O(1)$* |
| Reverse | $O(n)$ | $O(n)$ |
*With tail pointer
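The tail-pointer trick from the table also works for a singly linked list: tracking the last node makes append $O(1)$ instead of $O(n)$. A minimal sketch (deleting from the tail still costs $O(n)$ in a singly list, since you need the predecessor — that is what a doubly linked list fixes):

```python
class Node:
    def __init__(self, data):
        self.data = data
        self.next = None

class LinkedListWithTail:
    """Singly linked list that also tracks the tail, making append O(1)."""
    def __init__(self):
        self.head = None
        self.tail = None

    def append(self, data):
        node = Node(data)
        if not self.head:
            self.head = self.tail = node
        else:
            self.tail.next = node  # O(1): no traversal to find the end
            self.tail = node

    def to_list(self):
        out, current = [], self.head
        while current:
            out.append(current.data)
            current = current.next
        return out
```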
Best Practices
1. Use Sentinel Nodes
# Bad: Check for None multiple times
if head and head.next and head.next.next:
...
# Good: Use dummy node
dummy = Node(0)
dummy.next = head
current = dummy
# Now no need to check if current exists
2. Avoid Memory Leaks (C++)
// Always delete removed nodes
Node* temp = current->next;
current->next = current->next->next;
delete temp; // Free memory
3. Two-Pointer Technique
# Many problems solved with slow/fast pointers:
# - Find middle
# - Detect cycle
# - Remove nth from end
ELI10
Imagine a treasure hunt with clues:
- Each clue card (node) has treasure info and points to the next clue
- You start at the first clue (head)
- To find a specific clue, you must follow the chain - you can't jump!
- To add a clue in the middle, you just change what one card points to
- You don't need a big board to write all clues - they can be anywhere!
The tricky part: You can only look at clues in order, you can't jump to the middle one directly like you could with an array.
Stacks
Overview
A stack is a Last-In-First-Out (LIFO) data structure where elements are added and removed from the same end, called the top. Think of it like a stack of dinner plates - you put plates on top and take them from the top.
Key Concepts
LIFO Principle
Push: 1 -> 2 -> 3
3 Top (Last In)
2
1 First In
Pop: Returns 3 (First Out)
Operations & Time Complexity
| Operation | Time | Space |
|---|---|---|
| Push | $O(1)$ | $O(n)$ |
| Pop | $O(1)$ | - |
| Peek | $O(1)$ | - |
| Is Empty | $O(1)$ | - |
| Search | $O(n)$ | - |
Implementation (Python)
class Stack:
def __init__(self):
self.items = []
def push(self, data):
self.items.append(data)
def pop(self):
return self.items.pop() if not self.is_empty() else None
def peek(self):
return self.items[-1] if not self.is_empty() else None
def is_empty(self):
return len(self.items) == 0
def size(self):
return len(self.items)
# Using deque for $O(1)$ operations
from collections import deque
stack = deque()
stack.append(1) # Push $O(1)$
stack.pop() # Pop $O(1)$
Common Problems
Valid Parentheses
def is_valid(s):
stack = []
pairs = {'(': ')', '[': ']', '{': '}'}
for char in s:
if char in pairs:
stack.append(char)
else:
if not stack or pairs[stack.pop()] != char:
return False
return len(stack) == 0
Next Greater Element
def next_greater(arr):
stack = []
result = [-1] * len(arr)
for i in range(len(arr) - 1, -1, -1):
while stack and stack[-1] <= arr[i]:
stack.pop()
if stack:
result[i] = stack[-1]
stack.append(arr[i])
return result
Postfix Evaluation
def evaluate_postfix(expr):
stack = []
ops = {'+', '-', '*', '/'}
for token in expr.split():
if token not in ops:
stack.append(int(token))
else:
b = stack.pop()
a = stack.pop()
if token == '+': stack.append(a + b)
elif token == '-': stack.append(a - b)
elif token == '*': stack.append(a * b)
else: stack.append(int(a / b))  # truncate toward zero; a // b floors for negatives
return stack[0]
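The interview prep list mentions implementing a queue with stacks; it fits naturally here as another stack problem. The sketch below uses two stacks: enqueue pushes onto `inbox`, and dequeue pops from `outbox`, refilling it by draining `inbox` when empty. Each element is moved at most twice, so dequeue is $O(1)$ amortized.

```python
class QueueWithStacks:
    """FIFO queue built from two LIFO stacks (a classic interview exercise)."""
    def __init__(self):
        self.inbox = []    # receives enqueued items
        self.outbox = []   # holds items in dequeue order

    def enqueue(self, item):
        self.inbox.append(item)

    def dequeue(self):
        if not self.outbox:
            # Transferring reverses the order: oldest item ends up on top
            while self.inbox:
                self.outbox.append(self.inbox.pop())
        return self.outbox.pop() if self.outbox else None
```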
Real-World Uses
- Browser Back Button: Last visited page is first to go back to
- Function Call Stack: Each function call pushed, returns pop
- Undo/Redo: Last action undone first
- Expression Parsing: Manage operator precedence
- DFS (Depth-First Search): Graph traversal
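The DFS use above deserves a concrete sketch: replacing recursion with an explicit stack gives the same traversal without risking recursion-depth limits. This assumes the graph is an adjacency dict mapping each node to a list of neighbors.

```python
def dfs_iterative(graph, start):
    """Depth-first traversal driven by an explicit stack instead of recursion."""
    visited = []
    seen = {start}
    stack = [start]
    while stack:
        node = stack.pop()  # LIFO: most recently discovered node first
        visited.append(node)
        # Push neighbors in reverse so they are visited in listed order
        for neighbor in reversed(graph.get(node, [])):
            if neighbor not in seen:
                seen.add(neighbor)
                stack.append(neighbor)
    return visited
```

Swapping the stack for a queue (and `pop()` for `popleft()`) turns this into BFS — the data structure alone determines the traversal order.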
ELI10
Imagine a stack of dinner plates - you:
- Add new plates on top
- Take plates from top
- Can't grab from the middle without removing top plates
That's LIFO! Last In = First Out. The last plate you put on is the first one you take off.
Queues
Overview
A queue is a linear data structure that follows the First-In-First-Out (FIFO) principle: the first element added to the queue is the first one removed. Queues are commonly used where order must be preserved, such as scheduling tasks, managing requests in a server, or handling asynchronous data.
Key Operations
- Enqueue: Add an element to the end of the queue.
- Dequeue: Remove an element from the front of the queue.
- Peek/Front: Get the element at the front of the queue without removing it.
- IsEmpty: Check if the queue is empty.
- Size: Get the number of elements in the queue.
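The operations above can be sketched in Python. A plain list works, but `list.pop(0)` is $O(n)$ because every remaining element shifts left; `collections.deque` gives $O(1)$ at both ends.

```python
from collections import deque

class Queue:
    """FIFO queue backed by collections.deque for O(1) enqueue and dequeue."""
    def __init__(self):
        self._items = deque()

    def enqueue(self, item):
        self._items.append(item)       # add at the rear

    def dequeue(self):
        return self._items.popleft() if self._items else None  # remove from the front

    def peek(self):
        return self._items[0] if self._items else None

    def is_empty(self):
        return not self._items

    def size(self):
        return len(self._items)
```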
Types of Queues
- Simple Queue: Also known as a linear queue, where insertion happens at the rear and deletion happens at the front.
- Circular Queue: A more efficient queue where the last position is connected back to the first position to make a circle.
- Priority Queue: Each element is associated with a priority, and elements are served based on their priority.
- Double-ended Queue (Deque): Insertion and deletion can happen at both the front and the rear of the queue.
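Of the types above, the circular queue is the least obvious to implement, so here is a minimal fixed-capacity sketch. Head and tail indices wrap around with modular arithmetic, so slots freed by dequeues at the front are reused by enqueues at the rear — no shifting required.

```python
class CircularQueue:
    """Fixed-capacity circular queue (ring buffer)."""
    def __init__(self, capacity):
        self.buffer = [None] * capacity
        self.capacity = capacity
        self.head = 0    # index of the front element
        self.count = 0   # number of stored elements

    def enqueue(self, item):
        if self.count == self.capacity:
            return False                               # full
        tail = (self.head + self.count) % self.capacity
        self.buffer[tail] = item
        self.count += 1
        return True

    def dequeue(self):
        if self.count == 0:
            return None                                # empty
        item = self.buffer[self.head]
        self.head = (self.head + 1) % self.capacity    # wrap around
        self.count -= 1
        return item
```

For the priority queue variant, Python's `heapq` module is the usual building block, and `collections.deque` already is a double-ended queue.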
Applications of Queues
- CPU Scheduling: Managing processes in operating systems.
- Disk Scheduling: Managing I/O requests.
- Breadth-First Search (BFS): Traversing or searching tree or graph data structures.
- Print Queue: Managing print jobs in a printer.
Queues are fundamental data structures that are widely used in computer science and programming for managing ordered collections of items.
Hash Tables
Overview
A hash table (hash map) stores key-value pairs with $O(1)$ average-case lookup, insertion, and deletion. It uses a hash function to map keys to array indices.
How It Works
Hash Function
Converts key to index:
hash("name") = 5
hash(123) = 2
hash("email") = 5 (collision!)
Collision Resolution
Chaining: Store multiple values at same index
Index 0: None
Index 1: None
Index 2: 123 -> "John"
Index 3: 456 -> "Jane" -> 789 -> "Jack"
Open Addressing: Find next empty slot
hash("a") = 5 (occupied)
Try 6, 7, 8... until empty
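The open-addressing scheme above can be sketched with linear probing: on a collision, step to the next slot (wrapping around) until you find an empty slot or the same key. This sketch omits resizing and deletion — a real implementation grows the table as the load factor rises and uses tombstones so deleted slots don't break probe chains.

```python
class OpenAddressingHashTable:
    """Open addressing with linear probing (no resizing or deletion).

    Assumes the table never completely fills; otherwise probing for a
    missing key would loop forever.
    """
    def __init__(self, size=16):
        self.slots = [None] * size  # each slot holds (key, value) or None

    def _probe(self, key):
        index = hash(key) % len(self.slots)
        while self.slots[index] is not None and self.slots[index][0] != key:
            index = (index + 1) % len(self.slots)  # linear probe to next slot
        return index

    def set(self, key, value):
        self.slots[self._probe(key)] = (key, value)

    def get(self, key):
        slot = self.slots[self._probe(key)]
        return slot[1] if slot else None
```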
Operations
| Operation | Average | Worst |
|---|---|---|
| Get | $O(1)$ | $O(n)$ |
| Set | $O(1)$ | $O(n)$ |
| Delete | $O(1)$ | $O(n)$ |
Python Implementation
# Built-in dict
d = {"key": "value"}
d.get("key") # $O(1)$
d["key"] = "new_value"
del d["key"]
# Custom
class HashTable:
def __init__(self, size=10):
self.table = [[] for _ in range(size)]
def set(self, key, value):
index = hash(key) % len(self.table)
for i, (k, v) in enumerate(self.table[index]):
if k == key:
self.table[index][i] = (key, value)
return
self.table[index].append((key, value))
def get(self, key):
index = hash(key) % len(self.table)
for k, v in self.table[index]:
if k == key:
return v
return None
Common Problems
Two Sum
def two_sum(arr, target):
    seen = {}
    for i, num in enumerate(arr):
        if target - num in seen:
            return [seen[target - num], i]
        seen[num] = i  # store index directly; arr.index() breaks on duplicates
    return []
Duplicate Detection
def has_duplicates(arr):
return len(arr) != len(set(arr))
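The interview prep list also names the LRU cache, a classic hash-table problem. A minimal sketch using the standard library: `OrderedDict` remembers insertion order, `move_to_end()` marks an entry as most recently used, and `popitem(last=False)` evicts the oldest.

```python
from collections import OrderedDict

class LRUCache:
    """Least-recently-used cache with O(1) get and put."""
    def __init__(self, capacity):
        self.capacity = capacity
        self.cache = OrderedDict()

    def get(self, key):
        if key not in self.cache:
            return -1
        self.cache.move_to_end(key)        # mark as most recently used
        return self.cache[key]

    def put(self, key, value):
        if key in self.cache:
            self.cache.move_to_end(key)
        self.cache[key] = value
        if len(self.cache) > self.capacity:
            self.cache.popitem(last=False)  # evict least recently used
```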
ELI10
Think of hash tables like library catalogs:
- Hash function = catalog system (tells you which shelf)
- Index = shelf number
- Value = book
Instead of searching every shelf, the system instantly tells you which one!
Tree Traversal Algorithms
Tree traversal algorithms are methods used to visit all the nodes in a tree data structure in a specific order. These algorithms are essential for various operations on trees, such as searching, sorting, and manipulating data. There are several types of tree traversal algorithms, each with its own use cases and characteristics.
Types of Tree Traversal Algorithms
1. Depth-First Search (DFS)
Depth-First Search (DFS) is a traversal algorithm that explores as far as possible along each branch before backtracking. There are three common types of DFS traversals:
a. Preorder Traversal
In preorder traversal, the nodes are visited in the following order:
- Visit the root node.
- Traverse the left subtree.
- Traverse the right subtree.
Use cases: Used for creating a copy of the tree, prefix expression evaluation, and serializing trees.
Implementation (Recursive):
class TreeNode:
def __init__(self, val=0, left=None, right=None):
self.val = val
self.left = left
self.right = right
def preorder_traversal_recursive(root):
"""
Preorder traversal: Root -> Left -> Right
Time Complexity: $O(n)$ where n is the number of nodes
Space Complexity: $O(h)$ where h is the height (due to recursion stack)
"""
result = []
def traverse(node):
if not node:
return
result.append(node.val) # Visit root
traverse(node.left) # Traverse left subtree
traverse(node.right) # Traverse right subtree
traverse(root)
return result
Implementation (Iterative):
def preorder_traversal_iterative(root):
"""
Iterative preorder traversal using a stack.
Time Complexity: $O(n)$
Space Complexity: $O(h)$ — up to $O(n)$ for a skewed tree, $O(\log n)$ for a balanced tree
"""
if not root:
return []
result = []
stack = [root]
while stack:
node = stack.pop()
result.append(node.val)
# Push right first so left is processed first (stack is LIFO)
if node.right:
stack.append(node.right)
if node.left:
stack.append(node.left)
return result
Example:
Tree: 1
/ \
2 3
/ \
4 5
Preorder: [1, 2, 4, 5, 3]
Step-by-step:
1. Visit 1 (root)
2. Visit 2 (left child of 1)
3. Visit 4 (left child of 2)
4. Visit 5 (right child of 2)
5. Visit 3 (right child of 1)
b. Inorder Traversal
In inorder traversal, the nodes are visited in the following order:
- Traverse the left subtree.
- Visit the root node.
- Traverse the right subtree.
Use cases: For Binary Search Trees, inorder traversal gives nodes in sorted (ascending) order. Also used for expression tree evaluation.
Implementation (Recursive):
def inorder_traversal_recursive(root):
"""
Inorder traversal: Left -> Root -> Right
Time Complexity: $O(n)$
Space Complexity: $O(h)$ due to recursion stack
"""
result = []
def traverse(node):
if not node:
return
traverse(node.left) # Traverse left subtree
result.append(node.val) # Visit root
traverse(node.right) # Traverse right subtree
traverse(root)
return result
Implementation (Iterative):
def inorder_traversal_iterative(root):
"""
Iterative inorder traversal using a stack.
Time Complexity: $O(n)$
Space Complexity: $O(h)$
"""
result = []
stack = []
current = root
while current or stack:
# Go to the leftmost node
while current:
stack.append(current)
current = current.left
# Current is None, pop from stack
current = stack.pop()
result.append(current.val)
# Visit the right subtree
current = current.right
return result
Example:
Tree: 1
/ \
2 3
/ \
4 5
Inorder: [4, 2, 5, 1, 3]
Step-by-step:
1. Visit 4 (leftmost node)
2. Visit 2 (parent of 4)
3. Visit 5 (right child of 2)
4. Visit 1 (root)
5. Visit 3 (right child of 1)
For BST, this gives sorted order!
c. Postorder Traversal
In postorder traversal, the nodes are visited in the following order:
- Traverse the left subtree.
- Traverse the right subtree.
- Visit the root node.
Use cases: Used for deleting trees (delete children before parent), postfix expression evaluation, and calculating directory sizes.
Implementation (Recursive):
def postorder_traversal_recursive(root):
"""
Postorder traversal: Left -> Right -> Root
Time Complexity: $O(n)$
Space Complexity: $O(h)$
"""
result = []
def traverse(node):
if not node:
return
traverse(node.left) # Traverse left subtree
traverse(node.right) # Traverse right subtree
result.append(node.val) # Visit root
traverse(root)
return result
Implementation (Iterative):
def postorder_traversal_iterative(root):
"""
Iterative postorder traversal using two stacks.
Time Complexity: $O(n)$
Space Complexity: $O(h)$
"""
if not root:
return []
result = []
stack1 = [root]
stack2 = []
# Push nodes to stack2 in reverse postorder
while stack1:
node = stack1.pop()
stack2.append(node)
# Push left first, then right (opposite of preorder)
if node.left:
stack1.append(node.left)
if node.right:
stack1.append(node.right)
# Pop from stack2 to get postorder
while stack2:
result.append(stack2.pop().val)
return result
Example:
Tree: 1
/ \
2 3
/ \
4 5
Postorder: [4, 5, 2, 3, 1]
Step-by-step:
1. Visit 4 (leftmost leaf)
2. Visit 5 (right sibling of 4)
3. Visit 2 (parent of 4 and 5)
4. Visit 3 (leaf node)
5. Visit 1 (root, visited last)
2. Breadth-First Search (BFS) / Level Order Traversal
Breadth-First Search (BFS), also known as Level Order Traversal, is a traversal algorithm that explores all nodes at the present depth before moving to nodes at the next depth level. It uses a queue data structure.
Use cases: Finding shortest path in unweighted trees, level-by-level processing, serialization/deserialization of trees, finding all nodes at a given distance.
Implementation (Iterative):
from collections import deque
def level_order_traversal(root):
"""
Level order traversal using a queue (BFS).
Time Complexity: $O(n)$
Space Complexity: $O(w)$ where w is the maximum width of the tree
"""
if not root:
return []
result = []
queue = deque([root])
while queue:
node = queue.popleft()
result.append(node.val)
# Add children to queue
if node.left:
queue.append(node.left)
if node.right:
queue.append(node.right)
return result
Level-by-Level Implementation (returns list of lists):
def level_order_by_level(root):
"""
Returns nodes grouped by level.
Time Complexity: $O(n)$
Space Complexity: $O(w)$ where w is maximum width
"""
if not root:
return []
result = []
queue = deque([root])
while queue:
level_size = len(queue)
current_level = []
# Process all nodes at current level
for _ in range(level_size):
node = queue.popleft()
current_level.append(node.val)
if node.left:
queue.append(node.left)
if node.right:
queue.append(node.right)
result.append(current_level)
return result
Example:
Tree: 1
/ \
2 3
/ \ \
4 5 6
Level Order: [1, 2, 3, 4, 5, 6]
By Level: [[1], [2, 3], [4, 5, 6]]
Step-by-step:
Queue: [1] -> Visit 1, add children -> Result: [1]
Queue: [2, 3] -> Visit 2, add children -> Result: [1, 2]
Queue: [3, 4, 5] -> Visit 3, add children -> Result: [1, 2, 3]
Queue: [4, 5, 6] -> Visit 4 -> Result: [1, 2, 3, 4]
Queue: [5, 6] -> Visit 5 -> Result: [1, 2, 3, 4, 5]
Queue: [6] -> Visit 6 -> Result: [1, 2, 3, 4, 5, 6]
Variants:
def zigzag_level_order(root):
"""
Zigzag level order: alternate between left-to-right and right-to-left.
Example: [[1], [3, 2], [4, 5, 6]]
"""
if not root:
return []
result = []
queue = deque([root])
left_to_right = True
while queue:
level_size = len(queue)
current_level = deque()
for _ in range(level_size):
node = queue.popleft()
# Add to front or back based on direction
if left_to_right:
current_level.append(node.val)
else:
current_level.appendleft(node.val)
if node.left:
queue.append(node.left)
if node.right:
queue.append(node.right)
result.append(list(current_level))
left_to_right = not left_to_right
return result
def right_side_view(root):
"""
Return the values of nodes visible from the right side.
(Last node at each level)
"""
if not root:
return []
result = []
queue = deque([root])
while queue:
level_size = len(queue)
for i in range(level_size):
node = queue.popleft()
# Add last node of each level
if i == level_size - 1:
result.append(node.val)
if node.left:
queue.append(node.left)
if node.right:
queue.append(node.right)
return result
Binary Search Trees (BST)
A Binary Search Tree is a binary tree where for each node:
- All values in the left subtree are less than the node's value
- All values in the right subtree are greater than the node's value
- Both left and right subtrees are also BSTs
BST Operations
class BSTNode:
def __init__(self, val):
self.val = val
self.left = None
self.right = None
class BST:
def __init__(self):
self.root = None
def insert(self, val):
"""
Insert a value into the BST.
Time Complexity: $O(h)$ where h is height
Average: $O(\log n)$, Worst: $O(n)$ for skewed tree
"""
def _insert(node, val):
if not node:
return BSTNode(val)
if val < node.val:
node.left = _insert(node.left, val)
elif val > node.val:
node.right = _insert(node.right, val)
# If equal, don't insert (BST with unique values)
return node
self.root = _insert(self.root, val)
def search(self, val):
"""
Search for a value in the BST.
Time Complexity: $O(h)$
"""
def _search(node, val):
if not node or node.val == val:
return node
if val < node.val:
return _search(node.left, val)
else:
return _search(node.right, val)
return _search(self.root, val)
def delete(self, val):
"""
Delete a value from the BST.
Time Complexity: $O(h)$
"""
def _min_value_node(node):
"""Find the minimum value node in a subtree."""
current = node
while current.left:
current = current.left
return current
def _delete(node, val):
if not node:
return node
# Find the node to delete
if val < node.val:
node.left = _delete(node.left, val)
elif val > node.val:
node.right = _delete(node.right, val)
else:
# Node found! Handle three cases:
# Case 1: Node with only right child or no child
if not node.left:
return node.right
# Case 2: Node with only left child
if not node.right:
return node.left
# Case 3: Node with two children
# Get the inorder successor (smallest in right subtree)
successor = _min_value_node(node.right)
node.val = successor.val
node.right = _delete(node.right, successor.val)
return node
self.root = _delete(self.root, val)
def find_min(self):
"""Find minimum value (leftmost node)."""
if not self.root:
return None
current = self.root
while current.left:
current = current.left
return current.val
def find_max(self):
"""Find maximum value (rightmost node)."""
if not self.root:
return None
current = self.root
while current.right:
current = current.right
return current.val
def is_valid_bst(self):
"""
Validate if the tree is a valid BST.
Time Complexity: $O(n)$
"""
def _validate(node, min_val, max_val):
if not node:
return True
if node.val <= min_val or node.val >= max_val:
return False
return (_validate(node.left, min_val, node.val) and
_validate(node.right, node.val, max_val))
return _validate(self.root, float('-inf'), float('inf'))
# Example usage
bst = BST()
for val in [5, 3, 7, 2, 4, 6, 8]:
bst.insert(val)
print(bst.search(4)) # Returns the BSTNode containing 4 (None if absent)
print(bst.find_min()) # 2
print(bst.find_max()) # 8
BST Example:
5
/ \
3 7
/ \ / \
2 4 6 8
Inorder: [2, 3, 4, 5, 6, 7, 8] (sorted!)
Search for 4: 5 -> 3 -> 4 (3 steps)
Balanced Binary Search Trees
AVL Trees
AVL trees are self-balancing BSTs where the height difference between left and right subtrees (balance factor) is at most 1 for every node.
Balance Factor = height(left subtree) - height(right subtree)
- Must be in {-1, 0, 1}
class AVLNode:
def __init__(self, val):
self.val = val
self.left = None
self.right = None
self.height = 1 # Height of node
class AVLTree:
def get_height(self, node):
"""Get height of a node."""
if not node:
return 0
return node.height
def get_balance(self, node):
"""Get balance factor of a node."""
if not node:
return 0
return self.get_height(node.left) - self.get_height(node.right)
def update_height(self, node):
"""Update height of a node."""
if not node:
return 0
node.height = 1 + max(self.get_height(node.left),
self.get_height(node.right))
def rotate_right(self, y):
"""
Right rotation:
y x
/ \ / \
x C --> A y
/ \ / \
A B B C
"""
x = y.left
B = x.right
# Perform rotation
x.right = y
y.left = B
# Update heights
self.update_height(y)
self.update_height(x)
return x
def rotate_left(self, x):
"""
Left rotation:
x y
/ \ / \
A y --> x C
/ \ / \
B C A B
"""
y = x.right
B = y.left
# Perform rotation
y.left = x
x.right = B
# Update heights
self.update_height(x)
self.update_height(y)
return y
def insert(self, root, val):
"""
Insert a value and rebalance the tree.
Time Complexity: $O(\log n)$ - guaranteed!
"""
# 1. Perform standard BST insert
if not root:
return AVLNode(val)
if val < root.val:
root.left = self.insert(root.left, val)
elif val > root.val:
root.right = self.insert(root.right, val)
else:
return root # Duplicate values not allowed
# 2. Update height of current node
self.update_height(root)
# 3. Get balance factor
balance = self.get_balance(root)
# 4. If unbalanced, there are 4 cases:
# Left-Left Case
if balance > 1 and val < root.left.val:
return self.rotate_right(root)
# Right-Right Case
if balance < -1 and val > root.right.val:
return self.rotate_left(root)
# Left-Right Case
if balance > 1 and val > root.left.val:
root.left = self.rotate_left(root.left)
return self.rotate_right(root)
# Right-Left Case
if balance < -1 and val < root.right.val:
root.right = self.rotate_right(root.right)
return self.rotate_left(root)
return root
AVL Tree Rotations Explained:
Left-Left (LL) Imbalance:
Insert 1, 2, 3 into BST creates:
3 2
/ / \
2 --> 1 3
/
1
(Right rotation at 3)
Right-Right (RR) Imbalance:
Insert 3, 2, 1:
1 2
\ / \
2 --> 1 3
\
3
(Left rotation at 1)
Left-Right (LR) Imbalance:
3 3 2
/ / / \
1 --> 2 --> 1 3
\ /
2 1
(Left at 1, then Right at 3)
Right-Left (RL) Imbalance:
1 1 2
\ \ / \
3 --> 2 --> 1 3
/ \
2 3
(Right at 3, then Left at 1)
Common Tree Problems and Patterns
1. Tree Height/Depth
def max_depth(root):
"""
Find the maximum depth of a binary tree.
Time: $O(n)$, Space: $O(h)$
"""
if not root:
return 0
return 1 + max(max_depth(root.left), max_depth(root.right))
def min_depth(root):
"""
Find the minimum depth (root to nearest leaf).
Time: $O(n)$, Space: $O(h)$
"""
if not root:
return 0
if not root.left and not root.right:
return 1
if not root.left:
return 1 + min_depth(root.right)
if not root.right:
return 1 + min_depth(root.left)
return 1 + min(min_depth(root.left), min_depth(root.right))
2. Tree Diameter
def diameter_of_binary_tree(root):
"""
The diameter is the length of the longest path between any two nodes.
The path may or may not pass through the root.
Time: $O(n)$, Space: $O(h)$
"""
diameter = [0]
def height(node):
if not node:
return 0
left_height = height(node.left)
right_height = height(node.right)
# Update diameter (path through this node)
diameter[0] = max(diameter[0], left_height + right_height)
return 1 + max(left_height, right_height)
height(root)
return diameter[0]
3. Path Sum Problems
def has_path_sum(root, target_sum):
"""
Check if tree has root-to-leaf path that sums to target.
Time: $O(n)$, Space: $O(h)$
"""
if not root:
return False
if not root.left and not root.right:
return root.val == target_sum
remaining = target_sum - root.val
return (has_path_sum(root.left, remaining) or
has_path_sum(root.right, remaining))
def path_sum_all(root, target_sum):
"""
Find all root-to-leaf paths that sum to target.
Time: $O(n)$, Space: $O(h)$
"""
result = []
def dfs(node, current_sum, path):
if not node:
return
path.append(node.val)
current_sum += node.val
# Check if leaf node with target sum
if not node.left and not node.right and current_sum == target_sum:
result.append(path[:])
dfs(node.left, current_sum, path)
dfs(node.right, current_sum, path)
path.pop() # Backtrack
dfs(root, 0, [])
return result
4. Lowest Common Ancestor (LCA)
def lowest_common_ancestor(root, p, q):
"""
Find the lowest common ancestor of two nodes in a binary tree.
Time: $O(n)$, Space: $O(h)$
"""
if not root or root == p or root == q:
return root
left = lowest_common_ancestor(root.left, p, q)
right = lowest_common_ancestor(root.right, p, q)
# If both left and right are non-null, root is the LCA
if left and right:
return root
# Otherwise, return the non-null child
return left if left else right
def lca_bst(root, p, q):
"""
LCA for Binary Search Tree (more efficient).
Time: $O(h)$, Space: $O(1)$ iterative
"""
while root:
# Both nodes are in left subtree
if p.val < root.val and q.val < root.val:
root = root.left
# Both nodes are in right subtree
elif p.val > root.val and q.val > root.val:
root = root.right
# We've found the split point
else:
return root
5. Serialize and Deserialize
def serialize(root):
"""
Serialize a binary tree to a string.
Time: $O(n)$, Space: $O(n)$
"""
def dfs(node):
if not node:
return "None,"
return str(node.val) + "," + dfs(node.left) + dfs(node.right)
return dfs(root)
def deserialize(data):
"""
Deserialize a string to a binary tree.
Time: $O(n)$, Space: $O(n)$
"""
def dfs(values):
val = next(values)
if val == "None":
return None
node = TreeNode(int(val))
node.left = dfs(values)
node.right = dfs(values)
return node
return dfs(iter(data.split(",")))
6. Construct Trees from Traversals
def build_tree_from_inorder_preorder(preorder, inorder):
"""
Construct binary tree from preorder and inorder traversals.
    Time: $O(n^2)$ worst case for this simple version (list slicing and .index lookups); $O(n)$ with a value-to-index hashmap. Space: $O(n)$
"""
if not preorder or not inorder:
return None
# First element in preorder is the root
root_val = preorder[0]
root = TreeNode(root_val)
# Find root in inorder to split left/right subtrees
mid = inorder.index(root_val)
# Recursively build left and right subtrees
root.left = build_tree_from_inorder_preorder(
preorder[1:mid+1],
inorder[:mid]
)
root.right = build_tree_from_inorder_preorder(
preorder[mid+1:],
inorder[mid+1:]
)
return root
7. Tree Symmetry
def is_symmetric(root):
"""
Check if a tree is symmetric (mirror of itself).
Time: $O(n)$, Space: $O(h)$
"""
def is_mirror(left, right):
if not left and not right:
return True
if not left or not right:
return False
return (left.val == right.val and
is_mirror(left.left, right.right) and
is_mirror(left.right, right.left))
return is_mirror(root, root)
8. Flatten Tree to Linked List
def flatten_to_linked_list(root):
"""
Flatten binary tree to a linked list (preorder).
Time: $O(n)$, Space: $O(1)$
"""
if not root:
return
current = root
while current:
if current.left:
# Find the rightmost node of left subtree
rightmost = current.left
while rightmost.right:
rightmost = rightmost.right
# Connect it to current's right
rightmost.right = current.right
current.right = current.left
current.left = None
current = current.right
Complexity Cheat Sheet
| Operation | BST Average | BST Worst | AVL Tree | Red-Black Tree |
|---|---|---|---|---|
| Search | $O(\log n)$ | $O(n)$ | $O(\log n)$ | $O(\log n)$ |
| Insert | $O(\log n)$ | $O(n)$ | $O(\log n)$ | $O(\log n)$ |
| Delete | $O(\log n)$ | $O(n)$ | $O(\log n)$ | $O(\log n)$ |
| Space | $O(n)$ | $O(n)$ | $O(n)$ | $O(n)$ |
| Traversal | Time | Space |
|---|---|---|
| DFS (all) | $O(n)$ | $O(h)$ |
| BFS | $O(n)$ | $O(w)$ |
where:
- n = number of nodes
- h = height of tree
- w = maximum width of tree
Tips and Best Practices
When to Use Which Traversal?
-
Preorder (Root → Left → Right):
- Creating a copy of the tree
- Prefix expression of an expression tree
- Serialization of a tree
-
Inorder (Left → Root → Right):
- Getting sorted order from BST
- Validating BST
- Finding kth smallest element in BST
-
Postorder (Left → Right → Root):
- Deleting a tree (delete children before parent)
- Postfix expression evaluation
- Calculating size/height of subtrees
-
Level Order (BFS):
- Finding shortest path
- Level-by-level processing
- Finding nodes at distance k
- Checking if tree is complete
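One inorder use case listed above, finding the kth smallest element in a BST, can be sketched as follows. The function name is illustrative; a node class with `val`, `left`, and `right` attributes (as used throughout this section) is assumed:

```python
def kth_smallest(root, k):
    """Return the kth smallest value in a BST (1-indexed).

    Inorder traversal visits BST nodes in sorted order, so the
    kth node visited is the answer. Time: O(h + k), Space: O(h).
    """
    stack = []
    current = root
    while stack or current:
        # Walk left as far as possible
        while current:
            stack.append(current)
            current = current.left
        current = stack.pop()
        k -= 1
        if k == 0:
            return current.val
        current = current.right
    return None  # k exceeds the number of nodes
```

The iterative form lets the traversal stop as soon as the kth node is reached, instead of materializing the whole inorder list.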
Common Patterns
- Two Pointer Pattern: Use two recursive calls to traverse both sides (LCA, tree symmetry)
- Path Tracking: Use backtracking to track paths (path sum, all paths)
- Bottom-Up: Process children first, then parent (tree diameter, balanced tree check)
- Level Processing: Process one level at a time (level order variants)
- Divide and Conquer: Split problem into left and right subtrees (construct tree from traversals)
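The bottom-up pattern can be illustrated with a balanced-tree check, which the list above mentions. A sketch (function names are illustrative), computing each subtree's height exactly once:

```python
def is_balanced(root):
    """Check that every node's subtree heights differ by at most 1.

    Bottom-up: compute heights from the leaves upward and propagate
    -1 as soon as any imbalance is found. Time: O(n), Space: O(h).
    """
    def height(node):
        if not node:
            return 0
        left = height(node.left)
        if left == -1:
            return -1  # imbalance already found below
        right = height(node.right)
        if right == -1:
            return -1
        if abs(left - right) > 1:
            return -1  # this node is unbalanced
        return 1 + max(left, right)
    return height(root) != -1
```

Returning a sentinel (-1) instead of recomputing heights at every node is what keeps this O(n) rather than O(n log n) or worse.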
Interview Tips
-
Always ask about tree properties:
- Is it a BST?
- Is it balanced?
- Can it have duplicate values?
- Is it a complete/perfect binary tree?
-
Common edge cases to consider:
- Empty tree (root is None)
- Single node tree
- Skewed tree (all left or all right)
- Complete binary tree
- Perfect binary tree
-
Space vs Time tradeoffs:
- Recursive solutions: Clean code but $O(h)$ stack space
- Iterative solutions: More complex but explicit stack control
- Morris Traversal: $O(1)$ space but modifies tree temporarily
-
Optimization techniques:
- Early termination when answer is found
- Use BST property to skip half the tree
- Cache results to avoid recomputation
- Use iterative DP for bottom-up approaches
Morris Traversal ($O(1)$ Space)
For space-constrained environments, Morris Traversal allows inorder traversal with $O(1)$ space by temporarily modifying the tree:
def morris_inorder_traversal(root):
"""
Inorder traversal with $O(1)$ space.
Temporarily modifies tree structure but restores it.
Time: $O(n)$, Space: $O(1)$
"""
result = []
current = root
while current:
if not current.left:
# No left subtree, visit current and go right
result.append(current.val)
current = current.right
else:
# Find inorder predecessor
predecessor = current.left
while predecessor.right and predecessor.right != current:
predecessor = predecessor.right
if not predecessor.right:
# Create temporary link
predecessor.right = current
current = current.left
else:
# Remove temporary link
predecessor.right = None
result.append(current.val)
current = current.right
return result
Conclusion
Trees are fundamental data structures in computer science with wide-ranging applications:
- Tree traversals provide different ways to visit nodes, each with specific use cases
- Binary Search Trees enable efficient searching, insertion, and deletion operations
- Balanced trees (AVL, Red-Black) guarantee $O(\log n)$ operations even in worst case
- Understanding tree patterns is crucial for solving complex algorithmic problems
Key takeaways:
- Master all four traversal methods (preorder, inorder, postorder, level order)
- Understand both recursive and iterative implementations
- Practice common tree problems to recognize patterns
- Know when to use which tree data structure for optimal performance
- Always consider edge cases and space-time tradeoffs
With solid understanding of tree algorithms, you'll be well-equipped to tackle a wide variety of programming challenges!
Graphs
Graphs are a fundamental data structure used to represent relationships between pairs of objects. They consist of vertices (or nodes) and edges (connections between the nodes). Graphs can be directed or undirected, weighted or unweighted, and are widely used in various applications such as social networks, transportation systems, and computer networks.
Key Concepts
- Vertices: The individual elements or nodes in a graph.
- Edges: The connections between the vertices, which can represent relationships or paths.
- Directed Graphs: Graphs where the edges have a direction, indicating a one-way relationship.
- Undirected Graphs: Graphs where the edges do not have a direction, indicating a two-way relationship.
- Weighted Graphs: Graphs where edges have weights, representing costs or distances associated with the connections.
Common Algorithms
-
Depth-First Search (DFS): A traversal algorithm that explores as far as possible along each branch before backtracking. It can be implemented using recursion or a stack.
-
Breadth-First Search (BFS): A traversal algorithm that explores all neighbors at the present depth prior to moving on to nodes at the next depth level. It is typically implemented using a queue.
-
Dijkstra's Algorithm: An algorithm for finding the shortest paths between nodes in a weighted graph, which may represent, for example, road networks.
-
A* Search Algorithm: An extension of Dijkstra's algorithm that uses heuristics to improve the efficiency of pathfinding.
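BFS and DFS from the list above can be sketched over an adjacency-list graph (the dict-of-lists representation here is an assumption for illustration):

```python
from collections import deque

def bfs(graph, start):
    """Breadth-first traversal; returns nodes in visit order."""
    visited = {start}
    order = []
    queue = deque([start])
    while queue:
        node = queue.popleft()
        order.append(node)
        for neighbor in graph[node]:
            if neighbor not in visited:
                visited.add(neighbor)
                queue.append(neighbor)
    return order

def dfs(graph, start, visited=None):
    """Recursive depth-first traversal; returns nodes in visit order."""
    if visited is None:
        visited = set()
    visited.add(start)
    order = [start]
    for neighbor in graph[start]:
        if neighbor not in visited:
            order.extend(dfs(graph, neighbor, visited))
    return order
```

For example, with `graph = {'A': ['B', 'C'], 'B': ['D'], 'C': ['D'], 'D': []}`, `bfs(graph, 'A')` visits level by level while `dfs(graph, 'A')` follows each branch to its end first.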
Applications
Graphs are used in various applications, including:
- Social Networks: Representing users as vertices and their relationships as edges.
- Routing Algorithms: Finding the shortest path in navigation systems.
- Network Topology: Analyzing the structure of computer networks.
Conclusion
Graphs are a versatile and powerful data structure that can model complex relationships and interactions. Understanding graph theory and its associated algorithms is essential for solving a wide range of problems in computer science and software engineering.
Heaps
Heaps are a special tree-based data structure that satisfies the heap property. In a max heap, for any given node, the value of the node is greater than or equal to the values of its children, while in a min heap, the value of the node is less than or equal to the values of its children. Heaps are commonly used to implement priority queues and for efficient sorting algorithms.
Key Concepts
-
Heap Property: The key property that defines a heap, ensuring that the parent node is either greater than (max heap) or less than (min heap) its children.
-
Complete Binary Tree: Heaps are typically implemented as complete binary trees, where all levels are fully filled except possibly for the last level, which is filled from left to right.
Common Operations
-
Insertion: Adding a new element to the heap while maintaining the heap property. This is typically done by adding the element at the end of the tree and then "bubbling up" to restore the heap property.
-
Deletion: Removing the root element (the maximum or minimum) from the heap. This involves replacing the root with the last element in the tree and then "bubbling down" to restore the heap property.
-
Heapify: The process of converting an arbitrary array into a heap. This can be done in linear time using the bottom-up approach.
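All three operations are available through Python's built-in `heapq` module, which maintains a min heap over a plain list:

```python
import heapq

data = [9, 4, 7, 1, 8, 3]

# Heapify: bottom-up conversion into a min heap, O(n)
heapq.heapify(data)
print(data[0])  # 1 -- the root holds the minimum

# Insertion: append at the end and bubble up, O(log n)
heapq.heappush(data, 0)

# Deletion: pop the root, move the last element up, bubble down, O(log n)
smallest = heapq.heappop(data)
print(smallest)  # 0

# Popping repeatedly yields sorted order -- the idea behind heap sort
print([heapq.heappop(data) for _ in range(len(data))])  # [1, 3, 4, 7, 8, 9]
```

`heapq` only provides a min heap; a common trick for max-heap behavior is to push negated keys.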
Applications
Heaps are widely used in various applications, including:
-
Priority Queues: Heaps provide an efficient way to implement priority queues, allowing for quick access to the highest (or lowest) priority element.
-
Heap Sort: A comparison-based sorting algorithm that uses the heap data structure to sort elements in $O(n \log n)$ time.
-
Graph Algorithms: Heaps are used in algorithms like Dijkstra's and Prim's to efficiently manage the set of vertices being processed.
Conclusion
Heaps are a versatile data structure that provides efficient solutions for various problems, particularly those involving priority management and sorting. Understanding heaps and their operations is essential for developing efficient algorithms in computer science and software engineering.
Tries
A trie, also known as a prefix tree, is a specialized tree data structure for storing a dynamic set of strings, in which strings with a common prefix share nodes. A common application of a trie is storing a predictive text or autocomplete dictionary.
Key Concepts
-
Nodes: Each node in a trie represents a single character of a string. The root node represents an empty string.
-
Edges: The connections between nodes represent the characters that make up the strings stored in the trie.
-
Words: A word is formed by traversing from the root to a node that marks the end of a string.
Common Operations
-
Insertion: Adding a new word to the trie involves creating nodes for each character in the word and linking them together.
-
Search: To check if a word exists in the trie, traverse the nodes according to the characters in the word. If you reach the end of the word and the last node is marked as a complete word, the word exists in the trie.
-
Deletion: Removing a word from the trie involves traversing to the end of the word and removing nodes if they are no longer part of any other words.
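Insertion and search can be sketched with nested dictionaries as nodes (deletion is omitted for brevity; the class and method names are illustrative):

```python
class Trie:
    """Prefix tree using dicts as nodes; '$' marks end of a word."""

    def __init__(self):
        self.root = {}

    def insert(self, word):
        node = self.root
        for ch in word:
            node = node.setdefault(ch, {})  # create child if missing
        node['$'] = True  # end-of-word marker

    def search(self, word):
        """True only if word was inserted as a complete word."""
        node = self._walk(word)
        return node is not None and '$' in node

    def starts_with(self, prefix):
        """True if any inserted word begins with prefix."""
        return self._walk(prefix) is not None

    def _walk(self, s):
        node = self.root
        for ch in s:
            if ch not in node:
                return None
            node = node[ch]
        return node
```

The end-of-word marker is what distinguishes a stored word ("car") from a mere prefix of one ("ca").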
Applications
Tries are widely used in various applications, including:
-
Autocomplete Systems: Providing suggestions based on the prefix of the input string.
-
Spell Checkers: Checking the validity of words against a dictionary.
-
IP Routing: Storing routing tables in networking.
Conclusion
Tries are a powerful data structure for managing a dynamic set of strings, particularly useful for applications involving prefix searches and dictionary implementations. Understanding tries and their operations is essential for developing efficient algorithms in computer science and software engineering.
Algorithms
Overview
An algorithm is a step-by-step procedure or formula for solving a problem. In computer science, algorithms are fundamental to writing efficient and effective code. Understanding algorithms helps you choose the right approach for solving computational problems and optimize performance.
What is an Algorithm?
An algorithm must have these characteristics:
- Input: Zero or more inputs
- Output: At least one output
- Definiteness: Clear and unambiguous steps
- Finiteness: Must terminate after a finite number of steps
- Effectiveness: Steps must be basic enough to be executed
Algorithm Analysis
Time Complexity
Time complexity measures how the runtime of an algorithm grows with input size.
# O(1) - Constant time
def get_first_element(arr):
return arr[0] if arr else None
# O(n) - Linear time
def find_element(arr, target):
for elem in arr:
if elem == target:
return True
return False
# O(n²) - Quadratic time
def bubble_sort(arr):
n = len(arr)
for i in range(n):
for j in range(n - i - 1):
if arr[j] > arr[j + 1]:
arr[j], arr[j + 1] = arr[j + 1], arr[j]
return arr
# O(log n) - Logarithmic time
def binary_search(arr, target):
left, right = 0, len(arr) - 1
while left <= right:
mid = (left + right) // 2
if arr[mid] == target:
return mid
elif arr[mid] < target:
left = mid + 1
else:
right = mid - 1
return -1
# O(n log n) - Linearithmic time
def merge_sort(arr):
if len(arr) <= 1:
return arr
mid = len(arr) // 2
left = merge_sort(arr[:mid])
right = merge_sort(arr[mid:])
return merge(left, right)
Space Complexity
Space complexity measures the amount of memory an algorithm uses.
# O(1) space - In-place
def reverse_array_inplace(arr):
left, right = 0, len(arr) - 1
while left < right:
arr[left], arr[right] = arr[right], arr[left]
left += 1
right -= 1
# O(n) space - Additional array
def reverse_array_new(arr):
return arr[::-1]
# O(n) space - Recursion stack
def factorial(n):
if n <= 1:
return 1
return n * factorial(n - 1)
Algorithm Categories
1. Sorting Algorithms
Transform data into a specific order (ascending/descending).
Common Sorting Algorithms:
- Bubble Sort - O(n²)
- Selection Sort - O(n²)
- Insertion Sort - O(n²)
- Merge Sort - O(n log n)
- Quick Sort - O(n log n) average
- Heap Sort - O(n log n)
See: Sorting Algorithms
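Quick sort is listed above but not implemented elsewhere on this page; a simple out-of-place sketch with a random pivot:

```python
import random

def quick_sort(arr):
    """Average O(n log n), worst O(n^2); a random pivot makes the
    worst case unlikely. Returns a new sorted list."""
    if len(arr) <= 1:
        return arr
    pivot = random.choice(arr)
    # Partition into three parts around the pivot
    less = [x for x in arr if x < pivot]
    equal = [x for x in arr if x == pivot]
    greater = [x for x in arr if x > pivot]
    return quick_sort(less) + equal + quick_sort(greater)
```

Production quick sorts partition in place to avoid the O(n) extra space this version uses, but the three-way split shown here handles duplicates cleanly.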
2. Searching Algorithms
Find specific elements in data structures.
Common Searching Algorithms:
- Linear Search - O(n)
- Binary Search - O(log n)
- Jump Search - O(√n)
- Interpolation Search - O(log log n) average (on uniformly distributed data)
See: Searching Algorithms
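Jump search, listed above, does not appear elsewhere in these notes; a sketch over a sorted list:

```python
import math

def jump_search(arr, target):
    """Search a sorted array by jumping sqrt(n) steps, then scanning
    linearly within the block that could hold target. Time: O(sqrt n)."""
    n = len(arr)
    if n == 0:
        return -1
    step = math.isqrt(n) or 1
    # Jump forward block by block until the block's last element >= target
    prev = 0
    while prev < n and arr[min(prev + step, n) - 1] < target:
        prev += step
    # Linear scan within the identified block
    for i in range(prev, min(prev + step, n)):
        if arr[i] == target:
            return i
    return -1
```

Jump search makes fewer backward comparisons than binary search, which historically made it attractive when stepping backward was expensive (e.g. on tape or linked media).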
3. Graph Algorithms
Solve problems related to graph structures.
Common Graph Algorithms:
- Breadth-First Search (BFS)
- Depth-First Search (DFS)
- Dijkstra's Algorithm
- Bellman-Ford Algorithm
- Floyd-Warshall Algorithm
- Kruskal's Algorithm
- Prim's Algorithm
See: Graph Algorithms
4. Tree Algorithms
Operations on tree data structures.
Common Tree Algorithms:
- Tree Traversals (Inorder, Preorder, Postorder, Level-order)
- Binary Search Tree Operations
- AVL Tree Balancing
- Red-Black Tree Operations
- Trie Operations
See: Tree Algorithms
5. Dynamic Programming
Break complex problems into simpler subproblems and store results.
Classic DP Problems:
- Fibonacci Sequence
- Longest Common Subsequence
- Knapsack Problem
- Matrix Chain Multiplication
- Edit Distance
See: Dynamic Programming
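Fibonacci, the first problem in the list above, illustrates both standard DP approaches, top-down memoization and bottom-up tabulation:

```python
def fib_memo(n, memo=None):
    """Top-down DP: cache subproblem results. Time: O(n), Space: O(n)."""
    if memo is None:
        memo = {}
    if n <= 1:
        return n
    if n not in memo:
        memo[n] = fib_memo(n - 1, memo) + fib_memo(n - 2, memo)
    return memo[n]

def fib_tab(n):
    """Bottom-up DP: build from the base cases, keeping only the last
    two values. Time: O(n), Space: O(1)."""
    if n <= 1:
        return n
    prev, curr = 0, 1
    for _ in range(2, n + 1):
        prev, curr = curr, prev + curr
    return curr
```

Without the cache, the naive recursion recomputes the same subproblems exponentially many times; memoization collapses that to one computation per value of n.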
6. Greedy Algorithms
Make locally optimal choices at each step.
Common Greedy Problems:
- Activity Selection
- Huffman Coding
- Fractional Knapsack
- Coin Change (greedy variant)
- Job Sequencing
See: Greedy Algorithms
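Activity selection, the first problem above, shows the greedy idea directly: always take the activity that finishes earliest among those still compatible. A sketch over `(start, finish)` pairs:

```python
def activity_selection(activities):
    """Select a maximum set of non-overlapping activities.

    Greedy choice: sort by finish time and always pick the next
    activity whose start is at or after the last chosen finish.
    Time: O(n log n) for the sort."""
    selected = []
    last_finish = float('-inf')
    for start, finish in sorted(activities, key=lambda a: a[1]):
        if start >= last_finish:
            selected.append((start, finish))
            last_finish = finish
    return selected
```

The earliest-finish choice is provably optimal here because finishing sooner never reduces the room left for future activities.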
7. Divide and Conquer
Divide problem into subproblems, solve recursively, combine results.
Examples:
- Merge Sort
- Quick Sort
- Binary Search
- Strassen's Matrix Multiplication
- Closest Pair of Points
See: Divide and Conquer
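A compact example of the divide-recurse-combine shape (beyond the ones listed) is fast exponentiation by halving the exponent:

```python
def power(base, exp):
    """Compute base**exp for a non-negative integer exponent.

    Divide: halve the exponent; conquer: recurse on the half;
    combine: square the result (times base if exp is odd).
    Time: O(log exp)."""
    if exp == 0:
        return 1
    half = power(base, exp // 2)
    if exp % 2 == 0:
        return half * half
    return half * half * base
```

The recursion depth is log2(exp), so computing e.g. `power(2, 10)` takes 4 recursive calls rather than 10 multiplications.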
8. Backtracking
Try all possibilities and backtrack when stuck.
Classic Problems:
- N-Queens Problem
- Sudoku Solver
- Permutations and Combinations
- Graph Coloring
- Hamiltonian Path
See: Backtracking
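Permutations, from the list above, show the choose-explore-unchoose rhythm of backtracking:

```python
def permutations(nums):
    """Generate all permutations of nums via backtracking.

    At each step choose one remaining element, recurse on the rest,
    then undo the choice. Time: O(n * n!)."""
    result = []
    def backtrack(path, remaining):
        if not remaining:
            result.append(path[:])  # record a complete permutation
            return
        for i in range(len(remaining)):
            path.append(remaining[i])                          # choose
            backtrack(path, remaining[:i] + remaining[i + 1:])  # explore
            path.pop()                                          # un-choose
    backtrack([], nums)
    return result
```

The `path.pop()` is the "backtrack" step: it restores the state so the loop can try the next candidate from the same position.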
9. Recursion
Function calls itself to solve problems.
Examples:
- Factorial
- Fibonacci
- Tower of Hanoi
- Tree Traversals
- Divide and Conquer algorithms
See: Recursion
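Tower of Hanoi, listed above, is a classic recursion example: moving n disks reduces to moving n-1 disks twice plus one direct move, giving 2^n - 1 moves total:

```python
def hanoi(n, source, target, auxiliary, moves=None):
    """Move n disks from source to target using one auxiliary peg.

    Recurrence T(n) = 2*T(n-1) + 1, so T(n) = 2^n - 1 moves.
    Returns the list of (from_peg, to_peg) moves."""
    if moves is None:
        moves = []
    if n == 0:
        return moves
    hanoi(n - 1, source, auxiliary, target, moves)  # clear the way
    moves.append((source, target))                   # move largest disk
    hanoi(n - 1, auxiliary, target, source, moves)  # stack back on top
    return moves
```

For example, `hanoi(3, 'A', 'C', 'B')` produces 7 moves, beginning and ending with a move from A to C.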
Common Algorithm Patterns
Two Pointers
# Find pair with given sum in sorted array
def find_pair_with_sum(arr, target):
left, right = 0, len(arr) - 1
while left < right:
current_sum = arr[left] + arr[right]
if current_sum == target:
return (arr[left], arr[right])
elif current_sum < target:
left += 1
else:
right -= 1
return None
# Remove duplicates from sorted array
def remove_duplicates(arr):
if not arr:
return 0
write_index = 1
for read_index in range(1, len(arr)):
if arr[read_index] != arr[read_index - 1]:
arr[write_index] = arr[read_index]
write_index += 1
return write_index
Sliding Window
# Maximum sum subarray of size k
def max_sum_subarray(arr, k):
if len(arr) < k:
return None
# Calculate sum of first window
window_sum = sum(arr[:k])
max_sum = window_sum
# Slide window
for i in range(k, len(arr)):
window_sum = window_sum - arr[i - k] + arr[i]
max_sum = max(max_sum, window_sum)
return max_sum
# Longest substring without repeating characters
def longest_unique_substring(s):
char_index = {}
max_length = 0
start = 0
for end in range(len(s)):
if s[end] in char_index and char_index[s[end]] >= start:
start = char_index[s[end]] + 1
char_index[s[end]] = end
max_length = max(max_length, end - start + 1)
return max_length
Fast and Slow Pointers
# Detect cycle in linked list
def has_cycle(head):
if not head:
return False
slow = fast = head
while fast and fast.next:
slow = slow.next
fast = fast.next.next
if slow == fast:
return True
return False
# Find middle of linked list
def find_middle(head):
slow = fast = head
while fast and fast.next:
slow = slow.next
fast = fast.next.next
return slow
Merge Intervals
# Merge overlapping intervals
def merge_intervals(intervals):
if not intervals:
return []
# Sort by start time
intervals.sort(key=lambda x: x[0])
merged = [intervals[0]]
for current in intervals[1:]:
last = merged[-1]
if current[0] <= last[1]:
# Overlapping - merge
merged[-1] = (last[0], max(last[1], current[1]))
else:
# Non-overlapping - add
merged.append(current)
return merged
Binary Search Pattern
# Find first occurrence
def find_first(arr, target):
left, right = 0, len(arr) - 1
result = -1
while left <= right:
mid = (left + right) // 2
if arr[mid] == target:
result = mid
right = mid - 1 # Continue searching left
elif arr[mid] < target:
left = mid + 1
else:
right = mid - 1
return result
# Find peak element
def find_peak(arr):
left, right = 0, len(arr) - 1
while left < right:
mid = (left + right) // 2
if arr[mid] > arr[mid + 1]:
right = mid
else:
left = mid + 1
return left
Problem-Solving Approach
1. Understand the Problem
- Read carefully
- Identify inputs and outputs
- Clarify constraints
- Consider edge cases
2. Plan Your Approach
- Think of similar problems
- Consider multiple solutions
- Analyze time/space complexity
- Choose appropriate data structures
3. Implement
- Write clean, readable code
- Use meaningful variable names
- Add comments for complex logic
- Handle edge cases
4. Test
- Test with sample inputs
- Test edge cases (empty, single element, large input)
- Test boundary conditions
- Verify correctness
5. Optimize
- Analyze bottlenecks
- Consider trade-offs
- Improve time/space complexity
- Refactor for clarity
Time Complexity Cheat Sheet
| Complexity | Name | Example |
|---|---|---|
| O(1) | Constant | Array access, hash table lookup |
| O(log n) | Logarithmic | Binary search |
| O(n) | Linear | Linear search, array traversal |
| O(n log n) | Linearithmic | Merge sort, quick sort (average) |
| O(n²) | Quadratic | Bubble sort, nested loops |
| O(n³) | Cubic | Triple nested loops |
| O(2ⁿ) | Exponential | Naive recursive Fibonacci |
| O(n!) | Factorial | Permutations |
Space Complexity Considerations
- In-place algorithms: O(1) space - modify input directly
- Recursion: O(n) space for call stack
- Memoization: Trade space for time
- Auxiliary data structures: Arrays, hash tables, etc.
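The memoization trade-off above can be shown with `functools.lru_cache`; the climbing-stairs problem is used here as an illustrative example:

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def ways_to_climb(n):
    """Number of ways to climb n stairs taking 1 or 2 steps at a time.

    The cache spends O(n) space to cut the naive exponential-time
    recursion down to linear time: each n is computed once."""
    if n <= 1:
        return 1
    return ways_to_climb(n - 1) + ways_to_climb(n - 2)
```

Without the decorator, `ways_to_climb(30)` would make millions of redundant calls; with it, each value from 0 to 30 is computed exactly once.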
Interview Tips
Common Algorithm Questions
-
Arrays and Strings
- Two Sum
- Reverse String
- Longest Substring
- Array Rotation
-
Linked Lists
- Reverse Linked List
- Detect Cycle
- Merge Two Lists
- Find Middle
-
Trees and Graphs
- Tree Traversals
- Validate BST
- Lowest Common Ancestor
- Graph BFS/DFS
-
Dynamic Programming
- Fibonacci
- Climbing Stairs
- Coin Change
- Longest Increasing Subsequence
-
Sorting and Searching
- Binary Search variants
- Merge K Sorted Lists
- Find Kth Largest
- Quick Select
Best Practices
- Communication: Think aloud
- Clarification: Ask questions
- Examples: Work through examples
- Optimization: Discuss trade-offs
- Testing: Verify with test cases
- Edge Cases: Consider all scenarios
- Clean Code: Write readable code
- Time Management: Don't get stuck
Practice Resources
Online Platforms
- LeetCode
- HackerRank
- CodeSignal
- Project Euler
- Codeforces
- AtCoder
Books
- "Introduction to Algorithms" (CLRS)
- "Algorithm Design Manual" (Skiena)
- "Cracking the Coding Interview"
- "Elements of Programming Interviews"
Available Topics
Explore detailed guides for specific algorithm types:
- Big O Notation - Understanding algorithm complexity
- Sorting Algorithms - Comprehensive sorting guide
- Searching Algorithms - Various search techniques
- Graph Algorithms - Graph traversal and algorithms
- Tree Algorithms - Tree operations and traversals
- Dynamic Programming - DP patterns and problems
- Greedy Algorithms - Greedy approach and examples
- Divide and Conquer - D&C strategy
- Backtracking - Backtracking techniques
- Recursion - Recursive problem solving
- Heaps - Heap data structure and algorithms
- Tries - Trie data structure and applications
- Raft Consensus - Distributed consensus algorithm for replicated logs
Quick Reference
Most Important Algorithms to Know
Sorting:
- Quick Sort
- Merge Sort
- Heap Sort
Searching:
- Binary Search
- Depth-First Search (DFS)
- Breadth-First Search (BFS)
Graph:
- Dijkstra's Algorithm
- Topological Sort
- Union-Find
Dynamic Programming:
- 0/1 Knapsack
- Longest Common Subsequence
- Edit Distance
String:
- KMP Pattern Matching
- Rabin-Karp
- Trie Operations
Next Steps
- Review Big O Notation for complexity analysis
- Practice with Sorting and Searching
- Master Recursion fundamentals
- Explore Dynamic Programming
- Study Graph Algorithms and Trees
- Practice on coding platforms
- Participate in coding contests
- Review and optimize solutions
Remember: The key to mastering algorithms is consistent practice and understanding the underlying patterns. Start with fundamentals and gradually tackle more complex problems.
Big O Notation
Big O notation is a mathematical concept used to describe the performance or complexity of an algorithm. Specifically, it characterizes algorithms in terms of their time or space requirements in relation to the size of the input data. Understanding Big O notation is crucial for evaluating the efficiency of algorithms and making informed decisions about which algorithm to use in a given situation.
Key Concepts
-
Time Complexity: This refers to the amount of time an algorithm takes to complete as a function of the length of the input. It helps in understanding how the execution time increases with the size of the input.
-
Space Complexity: This refers to the amount of memory an algorithm uses in relation to the input size. It is important to consider both time and space complexity when analyzing an algorithm.
Common Big O Notations
$O(1)$ - Constant Time
The execution time does not change regardless of the input size.
def get_first_element(arr):
return arr[0] # Always one operation
def hash_lookup(dictionary, key):
return dictionary[key] # Constant time hash table lookup
# Example
arr = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
print(get_first_element(arr)) # $O(1)$
Examples: Array access, hash table operations, simple arithmetic
$O(\log n)$ - Logarithmic Time
The execution time grows logarithmically as the input size increases.
def binary_search(arr, target):
left, right = 0, len(arr) - 1
while left <= right:
mid = (left + right) // 2
if arr[mid] == target:
return mid
elif arr[mid] < target:
left = mid + 1
else:
right = mid - 1
return -1
# Example: With 1000 elements, only ~10 comparisons needed
arr = list(range(1000))
print(binary_search(arr, 742)) # $O(\log n)$
Examples: Binary search, balanced binary tree operations
$O(n)$ - Linear Time
The execution time grows linearly with the input size.
def linear_search(arr, target):
for i, element in enumerate(arr):
if element == target:
return i
return -1
def find_max(arr):
max_val = arr[0]
for num in arr:
if num > max_val:
max_val = num
return max_val
# Example
arr = [3, 1, 4, 1, 5, 9, 2, 6, 5, 3]
print(linear_search(arr, 9)) # $O(n)$
print(find_max(arr)) # $O(n)$
Examples: Linear search, array traversal, finding min/max
$O(n \log n)$ - Linearithmic Time
Common in efficient sorting algorithms.
def merge_sort(arr):
if len(arr) <= 1:
return arr
mid = len(arr) // 2
left = merge_sort(arr[:mid])
right = merge_sort(arr[mid:])
return merge(left, right)
def merge(left, right):
result = []
i = j = 0
while i < len(left) and j < len(right):
if left[i] <= right[j]:
result.append(left[i])
i += 1
else:
result.append(right[j])
j += 1
result.extend(left[i:])
result.extend(right[j:])
return result
# Example
arr = [38, 27, 43, 3, 9, 82, 10]
print(merge_sort(arr)) # $O(n \log n)$
Examples: Merge sort, heap sort, quick sort (average case)
$O(n^2)$ - Quadratic Time
The execution time grows quadratically with the input size.
def bubble_sort(arr):
n = len(arr)
for i in range(n):
for j in range(0, n - i - 1):
if arr[j] > arr[j + 1]:
arr[j], arr[j + 1] = arr[j + 1], arr[j]
return arr
def find_duplicates_naive(arr):
duplicates = []
for i in range(len(arr)):
for j in range(i + 1, len(arr)):
if arr[i] == arr[j]:
duplicates.append(arr[i])
return duplicates
# Example
arr = [64, 34, 25, 12, 22, 11, 90]
print(bubble_sort(arr.copy())) # $O(n^2)$
Examples: Bubble sort, selection sort, insertion sort, nested loops
$O(2^n)$ - Exponential Time
The execution time doubles with each additional element.
def fibonacci_recursive(n):
if n <= 1:
return n
return fibonacci_recursive(n - 1) + fibonacci_recursive(n - 2)
def power_set(s):
if not s:
return [[]]
subsets = power_set(s[1:])
return subsets + [[s[0]] + subset for subset in subsets]
# Example (slow for large n!)
print(fibonacci_recursive(10)) # $O(2^n)$
print(power_set([1, 2, 3])) # $O(2^n)$
Examples: Recursive Fibonacci, generating all subsets
$O(n!)$ - Factorial Time
The execution time grows factorially with the input size.
def permutations(arr):
if len(arr) <= 1:
return [arr]
result = []
for i in range(len(arr)):
rest = arr[:i] + arr[i+1:]
for p in permutations(rest):
result.append([arr[i]] + p)
return result
# Example (very slow!)
print(permutations([1, 2, 3])) # $O(n!)$
# For n=10, this would generate 3,628,800 permutations!
Examples: Generating all permutations, traveling salesman (brute force)
Complexity Comparison
import time
import random
def compare_complexities(n):
# $O(1)$
start = time.time()
_ = n
o1_time = time.time() - start
# $O(\log n)$
start = time.time()
_ = n.bit_length()
olog_time = time.time() - start
# $O(n)$
start = time.time()
_ = sum(range(n))
on_time = time.time() - start
# $O(n \log n)$
start = time.time()
arr = list(range(n))
random.shuffle(arr)
_ = sorted(arr)
onlogn_time = time.time() - start
# $O(n^2)$
start = time.time()
for i in range(min(n, 1000)): # Limited to avoid long wait
for j in range(min(n, 1000)):
pass
on2_time = time.time() - start
print(f"n = {n}:")
print(f" $O(1)$: {o1_time:.6f}s")
print(f" $O(\log n)$: {olog_time:.6f}s")
print(f" $O(n)$: {on_time:.6f}s")
print(f" $O(n \log n)$:{onlogn_time:.6f}s")
print(f" $O(n^2)$: {on2_time:.6f}s (limited)")
# Example
compare_complexities(10000)
Space Complexity Examples
# $O(1)$ space - In-place
def reverse_array_inplace(arr):
left, right = 0, len(arr) - 1
while left < right:
arr[left], arr[right] = arr[right], arr[left]
left += 1
right -= 1
# $O(n)$ space - Additional array
def reverse_array_new(arr):
return arr[::-1]
# $O(n)$ space - Recursion stack
def factorial_recursive(n):
if n <= 1:
return 1
return n * factorial_recursive(n - 1)
# $O(n^2)$ space - 2D array
def create_matrix(n):
return [[0 for _ in range(n)] for _ in range(n)]
Best, Average, and Worst Case
Different scenarios can have different complexities:
def quick_sort(arr):
if len(arr) <= 1:
return arr
pivot = arr[len(arr) // 2]
left = [x for x in arr if x < pivot]
middle = [x for x in arr if x == pivot]
right = [x for x in arr if x > pivot]
return quick_sort(left) + middle + quick_sort(right)
# Quick Sort:
# Best case: $O(n \log n)$ - balanced partitions
# Average case: $O(n \log n)$
# Worst case: $O(n^2)$ - maximally unbalanced partitions (e.g., an already sorted array when the first or last element is chosen as pivot)
Analyzing Algorithm Complexity
def example_algorithm(arr):
n = len(arr)
# $O(1)$ - constant operations
first = arr[0]
last = arr[-1]
# $O(n)$ - single loop
total = sum(arr)
# $O(n^2)$ - nested loops
for i in range(n):
for j in range(n):
pass
# $O(n \log n)$ - sorting
sorted_arr = sorted(arr)
# Overall: $O(1) + O(n) + O(n^2) + O(n \log n) = O(n^2)$
# (Dominant term is $n^2$)
Big O Rules
- Drop constants: $O(2n) \to O(n)$
- Drop non-dominant terms: $O(n^2 + n) \to O(n^2)$
- Different inputs use different variables: $O(a + b)$ for two arrays
- Multiplication for nested: $O(a \times b)$ for nested loops over different arrays
# Rule 1: Drop constants
def example1(arr):
for item in arr: # $O(n)$
print(item)
for item in arr: # $O(n)$
print(item)
# Total: $O(2n) = O(n)$
# Rule 2: Drop non-dominant terms
def example2(arr):
for i in range(len(arr)): # $O(n)$
for j in range(len(arr)): # $O(n^2)$
print(i, j)
for item in arr: # $O(n)$
print(item)
# Total: $O(n^2 + n) = O(n^2)$
# Rule 3: Different inputs
def example3(arr1, arr2):
for item in arr1: # $O(a)$
print(item)
for item in arr2: # $O(b)$
print(item)
# Total: $O(a + b)$
# Rule 4: Multiplication for nested
def example4(arr1, arr2):
for item1 in arr1: # $O(a)$
for item2 in arr2: # $O(b)$
print(item1, item2)
# Total: $O(a \times b)$
Complexity Cheat Sheet
| Complexity | Name | Example Operations |
|---|---|---|
| $O(1)$ | Constant | Array access, hash lookup |
| $O(\log n)$ | Logarithmic | Binary search |
| $O(n)$ | Linear | Loop through array |
| $O(n \log n)$ | Linearithmic | Efficient sorting |
| $O(n^2)$ | Quadratic | Nested loops |
| $O(n^3)$ | Cubic | Triple nested loops |
| $O(2^n)$ | Exponential | Recursive Fibonacci |
| $O(n!)$ | Factorial | All permutations |
Growth Rates Visualization
For n = 100:
$O(1)$: 1 operation
$O(\log n)$: 7 operations
$O(n)$: 100 operations
$O(n \log n)$: ~700 operations
$O(n^2)$: 10,000 operations
$O(n^3)$: 1,000,000 operations
$O(2^n)$: $1.27 \times 10^{30}$ operations (intractable!)
$O(n!)$: $9.33 \times 10^{157}$ operations (impossible!)
Practical Tips
- Optimize bottlenecks: Focus on the most time-consuming parts
- Trade-offs: Sometimes $O(n)$ space can give $O(1)$ time (caching)
- Real-world considerations: Constants matter for small n
- Amortized analysis: Some operations are cheaper on average
- Choose appropriately: Don't over-optimize; $O(n^2)$ is fine for small n
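The space-for-time trade-off in tip 2 shows up clearly in duplicate detection: a hash set costs $O(n)$ extra space but drops the time from $O(n^2)$ to $O(n)$. A minimal sketch (function names are illustrative):

```python
def has_duplicates_naive(arr):
    # O(n^2) time, O(1) space: compare every pair
    for i in range(len(arr)):
        for j in range(i + 1, len(arr)):
            if arr[i] == arr[j]:
                return True
    return False

def has_duplicates_fast(arr):
    # O(n) time, O(n) space: remember what we've already seen
    seen = set()
    for x in arr:
        if x in seen:
            return True
        seen.add(x)
    return False

# Example usage
print(has_duplicates_fast([1, 2, 3, 2]))  # True
print(has_duplicates_fast([1, 2, 3]))     # False
```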
Conclusion
Big O notation provides a high-level understanding of the efficiency of algorithms, allowing developers to compare and choose the most suitable algorithm for their needs. By analyzing both time and space complexity, one can make informed decisions that lead to better performance in software applications.
Recursion
Recursion is a programming technique where a function calls itself in order to solve a problem. It is often used to break down complex problems into simpler subproblems.
Key Concepts
- Base Case: The condition under which the recursion ends. It prevents infinite loops and allows the function to return a result.
- Recursive Case: The part of the function where the recursion occurs, typically involving a call to the same function with modified arguments.
Factorial
The factorial of a non-negative integer n is the product of all positive integers less than or equal to n.
def factorial(n):
# Base case
if n == 0 or n == 1:
return 1
# Recursive case
return n * factorial(n - 1)
# Example usage
print(factorial(5)) # Output: 120
Fibonacci Sequence
The Fibonacci sequence where each number is the sum of the two preceding ones.
# Simple recursion (exponential time)
def fibonacci(n):
if n <= 1:
return n
return fibonacci(n - 1) + fibonacci(n - 2)
# With memoization (linear time)
def fibonacci_memo(n, memo={}):
if n in memo:
return memo[n]
if n <= 1:
return n
memo[n] = fibonacci_memo(n - 1, memo) + fibonacci_memo(n - 2, memo)
return memo[n]
# Example usage
print(fibonacci(10)) # Output: 55
print(fibonacci_memo(100)) # Much faster for large n
Binary Search
Recursive implementation of binary search.
def binary_search(arr, target, left, right):
# Base case: element not found
if left > right:
return -1
mid = left + (right - left) // 2
# Base case: element found
if arr[mid] == target:
return mid
# Recursive cases
if arr[mid] > target:
return binary_search(arr, target, left, mid - 1)
else:
return binary_search(arr, target, mid + 1, right)
# Example usage
arr = [1, 3, 5, 7, 9, 11, 13]
result = binary_search(arr, 7, 0, len(arr) - 1)
print(f"Element found at index: {result}") # Output: 3
Sum of Array
Calculate the sum of all elements in an array recursively.
def array_sum(arr):
# Base case: empty array
if not arr:
return 0
# Recursive case: first element + sum of rest
return arr[0] + array_sum(arr[1:])
# Optimized with an index (avoids copying a slice at each call)
def array_sum_optimized(arr, index=0):
if index == len(arr):
return 0
return arr[index] + array_sum_optimized(arr, index + 1)
# Example usage
numbers = [1, 2, 3, 4, 5]
print(array_sum(numbers)) # Output: 15
Power Function
Calculate x raised to the power n.
# Simple recursion
def power(x, n):
if n == 0:
return 1
return x * power(x, n - 1)
# Optimized (divide and conquer)
def power_optimized(x, n):
if n == 0:
return 1
half = power_optimized(x, n // 2)
if n % 2 == 0:
return half * half
else:
return x * half * half
# Example usage
print(power(2, 10)) # Output: 1024
print(power_optimized(2, 10)) # Faster for large n
String Reversal
Reverse a string using recursion.
def reverse_string(s):
# Base case: empty or single character
if len(s) <= 1:
return s
# Recursive case: last char + reverse of rest
return s[-1] + reverse_string(s[:-1])
# Alternative implementation
def reverse_string_alt(s):
if len(s) == 0:
return s
return reverse_string_alt(s[1:]) + s[0]
# Example usage
print(reverse_string("hello")) # Output: "olleh"
Palindrome Check
Check if a string is a palindrome recursively.
def is_palindrome(s, left=0, right=None):
if right is None:
right = len(s) - 1
# Base cases
if left >= right:
return True
if s[left] != s[right]:
return False
# Recursive case
return is_palindrome(s, left + 1, right - 1)
# Example usage
print(is_palindrome("racecar")) # Output: True
print(is_palindrome("hello")) # Output: False
Tree Traversals
Recursive tree traversal algorithms.
class TreeNode:
def __init__(self, val=0, left=None, right=None):
self.val = val
self.left = left
self.right = right
# Inorder traversal (left, root, right)
def inorder(root):
if root is None:
return []
return inorder(root.left) + [root.val] + inorder(root.right)
# Preorder traversal (root, left, right)
def preorder(root):
if root is None:
return []
return [root.val] + preorder(root.left) + preorder(root.right)
# Postorder traversal (left, right, root)
def postorder(root):
if root is None:
return []
return postorder(root.left) + postorder(root.right) + [root.val]
# Tree height
def tree_height(root):
if root is None:
return 0
return 1 + max(tree_height(root.left), tree_height(root.right))
# Example usage
#       1
#      / \
#     2   3
#    / \
#   4   5
root = TreeNode(1)
root.left = TreeNode(2)
root.right = TreeNode(3)
root.left.left = TreeNode(4)
root.left.right = TreeNode(5)
print("Inorder:", inorder(root)) # [4, 2, 5, 1, 3]
print("Preorder:", preorder(root)) # [1, 2, 4, 5, 3]
print("Postorder:", postorder(root))# [4, 5, 2, 3, 1]
print("Height:", tree_height(root)) # 3
Greatest Common Divisor (GCD)
Find GCD using Euclidean algorithm.
def gcd(a, b):
# Base case
if b == 0:
return a
# Recursive case
return gcd(b, a % b)
# Example usage
print(gcd(48, 18)) # Output: 6
Tower of Hanoi
Classic puzzle solved recursively.
def tower_of_hanoi(n, source, destination, auxiliary):
if n == 1:
print(f"Move disk 1 from {source} to {destination}")
return
# Move n-1 disks from source to auxiliary
tower_of_hanoi(n - 1, source, auxiliary, destination)
# Move nth disk from source to destination
print(f"Move disk {n} from {source} to {destination}")
# Move n-1 disks from auxiliary to destination
tower_of_hanoi(n - 1, auxiliary, destination, source)
# Example usage
tower_of_hanoi(3, 'A', 'C', 'B')
Flatten Nested List
Flatten a nested list structure.
def flatten(nested_list):
result = []
for item in nested_list:
if isinstance(item, list):
result.extend(flatten(item))
else:
result.append(item)
return result
# Example usage
nested = [1, [2, [3, 4], 5], 6, [7, 8]]
print(flatten(nested)) # Output: [1, 2, 3, 4, 5, 6, 7, 8]
Recursive Patterns
Common recursive patterns to recognize:
1. Linear Recursion
def linear_recursion(n):
if n == 0:
return 0
return n + linear_recursion(n - 1)
2. Binary Recursion
def binary_recursion(n):
if n <= 1:
return n
return binary_recursion(n - 1) + binary_recursion(n - 2)
3. Tail Recursion
def tail_recursion(n, accumulator=0):
if n == 0:
return accumulator
return tail_recursion(n - 1, accumulator + n)
Recursion vs Iteration
# Recursive factorial
def factorial_recursive(n):
if n <= 1:
return 1
return n * factorial_recursive(n - 1)
# Iterative factorial
def factorial_iterative(n):
result = 1
for i in range(1, n + 1):
result *= i
return result
Tips for Recursion
- Always define a base case: Prevents infinite recursion
- Make progress toward base case: Each recursive call should move closer to the base case
- Trust the recursion: Assume the recursive call works correctly for smaller inputs
- Consider stack depth: Deep recursion can cause stack overflow
- Use memoization: Cache results to avoid redundant calculations
- Know when to use iteration: Sometimes iteration is clearer and more efficient
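Tips 4 and 5 in practice: Python's `functools.lru_cache` adds memoization with a single decorator, and `sys.getrecursionlimit()` shows the default stack bound (typically 1000 frames in CPython), which is worth checking before deep recursion:

```python
import sys
from functools import lru_cache

@lru_cache(maxsize=None)
def fib(n):
    # Each n is computed once; later calls hit the cache
    if n <= 1:
        return n
    return fib(n - 1) + fib(n - 2)

print(fib(100))                  # fast, despite the naive-looking recursion
print(sys.getrecursionlimit())   # default recursion depth limit
```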
Common Pitfalls
# BAD: No base case (infinite recursion)
def bad_recursion(n):
return 1 + bad_recursion(n - 1) # Never stops!
# BAD: Doesn't make progress
def bad_recursion2(n):
if n == 0:
return 0
return bad_recursion2(n) # n never changes!
# GOOD: Proper base case and progress
def good_recursion(n):
if n == 0:
return 0
return 1 + good_recursion(n - 1) # n decreases
Applications
Recursion is widely used in various applications, including:
- Tree Traversals: Navigating through tree data structures using recursive methods
- Backtracking Algorithms: Solving problems incrementally by trying partial solutions
- Dynamic Programming: Many DP problems can be solved using recursive approaches with memoization
- Divide and Conquer: Breaking problems into smaller subproblems
- Mathematical Computations: Factorials, Fibonacci, GCD, etc.
Conclusion
Recursion is a powerful tool in programming that allows for elegant solutions to complex problems. Understanding how to effectively use recursion is essential for developing efficient algorithms in computer science and software engineering.
Dynamic Programming
Overview
Dynamic Programming (DP) solves complex problems by breaking them into simpler subproblems, solving each once, and storing results to avoid recomputation. Essential for optimization problems.
Key Concepts
Optimal Substructure: Optimal solution built from optimal solutions of subproblems
Overlapping Subproblems: Same subproblem computed multiple times
Approaches
Memoization (Top-Down)
def fib(n, memo={}):
if n in memo:
return memo[n]
if n <= 1:
return n
memo[n] = fib(n-1, memo) + fib(n-2, memo)
return memo[n]
Tabulation (Bottom-Up)
def fib(n):
    if n == 0:
        return 0  # guard: dp[1] below would be out of range for n == 0
    dp = [0] * (n + 1)
    dp[1] = 1
    for i in range(2, n + 1):
        dp[i] = dp[i-1] + dp[i-2]
    return dp[n]
Classic Problems
Coin Change
def coin_change(coins, amount):
dp = [float('inf')] * (amount + 1)
dp[0] = 0
for coin in coins:
for i in range(coin, amount + 1):
dp[i] = min(dp[i], dp[i - coin] + 1)
return dp[amount] if dp[amount] != float('inf') else -1
Knapsack
def knapsack(weights, values, capacity):
dp = [[0] * (capacity + 1) for _ in range(len(weights) + 1)]
for i in range(1, len(weights) + 1):
for w in range(capacity + 1):
if weights[i-1] <= w:
dp[i][w] = max(
values[i-1] + dp[i-1][w - weights[i-1]],
dp[i-1][w]
)
else:
dp[i][w] = dp[i-1][w]
return dp[len(weights)][capacity]
Longest Common Subsequence
def lcs(text1, text2):
m, n = len(text1), len(text2)
dp = [[0] * (n + 1) for _ in range(m + 1)]
for i in range(1, m + 1):
for j in range(1, n + 1):
if text1[i-1] == text2[j-1]:
dp[i][j] = dp[i-1][j-1] + 1
else:
dp[i][j] = max(dp[i-1][j], dp[i][j-1])
return dp[m][n]
Edit Distance
def editDistance(word1, word2):
m, n = len(word1), len(word2)
dp = [[0] * (n + 1) for _ in range(m + 1)]
for i in range(m + 1):
dp[i][0] = i
for j in range(n + 1):
dp[0][j] = j
for i in range(1, m + 1):
for j in range(1, n + 1):
if word1[i-1] == word2[j-1]:
dp[i][j] = dp[i-1][j-1]
else:
dp[i][j] = 1 + min(dp[i-1][j], dp[i][j-1], dp[i-1][j-1])
return dp[m][n]
Patterns
1D DP: dp[i] = f(dp[i-1], dp[i-2], ...)
2D DP: dp[i][j] = f(dp[i-1][j], dp[i][j-1], ...)
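A minimal instance of the 1D pattern is climbing stairs (count the ways to reach step n taking 1 or 2 steps at a time): dp[i] depends only on dp[i-1] and dp[i-2], so only the last two values need to be kept.

```python
def climb_stairs(n):
    # dp[i] = dp[i-1] + dp[i-2]; keep just the last two values, O(1) space
    a, b = 1, 1  # ways to reach step 0 and step 1
    for _ in range(2, n + 1):
        a, b = b, a + b
    return b

print(climb_stairs(5))  # 8
```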
Complexity
| Problem | Naive | DP |
|---|---|---|
| Fibonacci | $O(2^n)$ | $O(n)$ |
| Coin Change | $O(n^m)$ | $O(n \times m)$ |
| Knapsack | $O(2^n)$ | $O(n \times w)$ |
ELI10
DP is like climbing stairs - memorize steps you've already calculated instead of redoing them!
Key Concepts
- Overlapping Subproblems: Dynamic programming is applicable when the problem can be broken down into smaller, overlapping subproblems that can be solved independently. The results of these subproblems are stored to avoid redundant calculations.
- Optimal Substructure: A problem exhibits optimal substructure if an optimal solution to the problem can be constructed from optimal solutions to its subproblems. This property allows dynamic programming to build up solutions incrementally.
Techniques
- Top-Down Approach (Memoization): This approach involves solving the problem recursively and storing the results of subproblems in a table (or cache) to avoid redundant calculations. When a subproblem is encountered again, the stored result is used instead of recalculating it.
- Bottom-Up Approach (Tabulation): In this approach, the problem is solved iteratively by filling up a table based on previously computed values. This method typically starts with the smallest subproblems and builds up to the solution of the original problem.
Applications
Dynamic programming is widely used in various applications, including:
- Fibonacci Sequence: Calculating Fibonacci numbers can be optimized using dynamic programming to avoid exponential time complexity.
- Knapsack Problem: The 0/1 knapsack problem can be efficiently solved using dynamic programming techniques to maximize the total value of items that can be carried.
- Longest Common Subsequence: Finding the longest common subsequence between two strings can be accomplished using dynamic programming to build a solution based on previously computed subsequences.
Conclusion
Dynamic programming is a crucial technique in algorithm design that enables efficient solutions to problems with overlapping subproblems and optimal substructure. By leveraging memoization or tabulation, developers can significantly improve the performance of their algorithms, making dynamic programming an essential tool in computer science and software engineering.
Backtracking
Backtracking is a general algorithmic technique that incrementally builds candidates for solutions and abandons a candidate as soon as it is determined that it cannot lead to a valid solution. It is often used for solving constraint satisfaction problems, such as puzzles, combinatorial problems, and optimization problems.
Key Concepts
- Recursive Approach: Backtracking is typically implemented using recursion. The algorithm explores each possible option and recursively attempts to build a solution. If a solution is found, it is returned; if not, the algorithm backtracks to try the next option.
- State Space Tree: The process of backtracking can be visualized as a tree where each node represents a state of the solution. The root node represents the initial state, and each branch represents a choice made. The leaves of the tree represent complete solutions or dead ends.
- Pruning: One of the key advantages of backtracking is its ability to prune the search space. If a partial solution cannot lead to a valid complete solution, the algorithm can abandon that path early, thus saving time and resources.
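The effect of pruning can be measured directly by counting recursive calls in a combination-sum style search, with and without an early cutoff when the running sum would exceed the target. This is an illustrative helper of my own, not a standard function:

```python
def count_calls(candidates, target, prune):
    """Count recursive calls made while exploring sums up to target."""
    calls = 0

    def backtrack(start, current_sum):
        nonlocal calls
        calls += 1
        if current_sum >= target:
            return
        for i in range(start, len(candidates)):
            if prune and current_sum + candidates[i] > target:
                continue  # abandon this branch before recursing into it
            backtrack(i, current_sum + candidates[i])

    backtrack(0, 0)
    return calls

# Example: the pruned search visits strictly fewer states
print(count_calls([2, 3, 6, 7], 30, prune=False))
print(count_calls([2, 3, 6, 7], 30, prune=True))
```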
N-Queens Problem
Place N queens on an N×N chessboard such that no two queens threaten each other.
def solve_n_queens(n):
def is_valid(board, row, col):
# Check column
for i in range(row):
if board[i][col] == 'Q':
return False
# Check diagonal (top-left)
i, j = row - 1, col - 1
while i >= 0 and j >= 0:
if board[i][j] == 'Q':
return False
i -= 1
j -= 1
# Check diagonal (top-right)
i, j = row - 1, col + 1
while i >= 0 and j < n:
if board[i][j] == 'Q':
return False
i -= 1
j += 1
return True
def backtrack(board, row):
if row == n:
result.append([''.join(row) for row in board])
return
for col in range(n):
if is_valid(board, row, col):
board[row][col] = 'Q'
backtrack(board, row + 1)
board[row][col] = '.' # Backtrack
result = []
board = [['.' for _ in range(n)] for _ in range(n)]
backtrack(board, 0)
return result
# Example usage
solutions = solve_n_queens(4)
print(f"Found {len(solutions)} solutions for 4-Queens")
for solution in solutions:
for row in solution:
print(row)
print()
Sudoku Solver
Solve a 9×9 Sudoku puzzle.
def solve_sudoku(board):
def is_valid(board, row, col, num):
# Check row
if num in board[row]:
return False
# Check column
if num in [board[i][col] for i in range(9)]:
return False
# Check 3x3 box
box_row, box_col = 3 * (row // 3), 3 * (col // 3)
for i in range(box_row, box_row + 3):
for j in range(box_col, box_col + 3):
if board[i][j] == num:
return False
return True
def backtrack():
for row in range(9):
for col in range(9):
if board[row][col] == '.':
for num in '123456789':
if is_valid(board, row, col, num):
board[row][col] = num
if backtrack():
return True
board[row][col] = '.' # Backtrack
return False
return True
backtrack()
return board
# Example usage
board = [
["5","3",".",".","7",".",".",".","."],
["6",".",".","1","9","5",".",".","."],
[".","9","8",".",".",".",".","6","."],
["8",".",".",".","6",".",".",".","3"],
["4",".",".","8",".","3",".",".","1"],
["7",".",".",".","2",".",".",".","6"],
[".","6",".",".",".",".","2","8","."],
[".",".",".","4","1","9",".",".","5"],
[".",".",".",".","8",".",".","7","9"]
]
solve_sudoku(board)
Generate Subsets
Generate all subsets (power set) of a given set.
def subsets(nums):
result = []
def backtrack(start, path):
# Add current subset to result
result.append(path[:])
# Try adding each remaining element
for i in range(start, len(nums)):
path.append(nums[i])
backtrack(i + 1, path)
path.pop() # Backtrack
backtrack(0, [])
return result
# Example usage
nums = [1, 2, 3]
print(subsets(nums))
# Output: [[], [1], [1, 2], [1, 2, 3], [1, 3], [2], [2, 3], [3]]
Generate Permutations
Generate all permutations of a given list.
def permute(nums):
result = []
def backtrack(path, remaining):
if not remaining:
result.append(path[:])
return
for i in range(len(remaining)):
# Choose
path.append(remaining[i])
# Explore
backtrack(path, remaining[:i] + remaining[i+1:])
# Unchoose (backtrack)
path.pop()
backtrack([], nums)
return result
# Alternative implementation using swap
def permute_swap(nums):
result = []
def backtrack(first):
if first == len(nums):
result.append(nums[:])
return
for i in range(first, len(nums)):
nums[first], nums[i] = nums[i], nums[first]
backtrack(first + 1)
nums[first], nums[i] = nums[i], nums[first] # Backtrack
backtrack(0)
return result
# Example usage
nums = [1, 2, 3]
print(permute(nums))
# Output: [[1,2,3],[1,3,2],[2,1,3],[2,3,1],[3,1,2],[3,2,1]]
Combination Sum
Find all combinations that sum to a target value.
def combination_sum(candidates, target):
result = []
def backtrack(start, path, current_sum):
if current_sum == target:
result.append(path[:])
return
if current_sum > target:
return # Prune this branch
for i in range(start, len(candidates)):
path.append(candidates[i])
# Can reuse same element, so pass i (not i+1)
backtrack(i, path, current_sum + candidates[i])
path.pop() # Backtrack
backtrack(0, [], 0)
return result
# Example usage
candidates = [2, 3, 6, 7]
target = 7
print(combination_sum(candidates, target))
# Output: [[2,2,3], [7]]
Word Search
Find if a word exists in a 2D board.
def word_search(board, word):
rows, cols = len(board), len(board[0])
def backtrack(row, col, index):
# Found the word
if index == len(word):
return True
# Out of bounds or wrong character
if (row < 0 or row >= rows or
col < 0 or col >= cols or
board[row][col] != word[index]):
return False
# Mark as visited
temp = board[row][col]
board[row][col] = '#'
# Explore all directions
found = (backtrack(row + 1, col, index + 1) or
backtrack(row - 1, col, index + 1) or
backtrack(row, col + 1, index + 1) or
backtrack(row, col - 1, index + 1))
# Backtrack
board[row][col] = temp
return found
# Try starting from each cell
for row in range(rows):
for col in range(cols):
if backtrack(row, col, 0):
return True
return False
# Example usage
board = [
['A','B','C','E'],
['S','F','C','S'],
['A','D','E','E']
]
print(word_search(board, "ABCCED")) # True
print(word_search(board, "SEE")) # True
print(word_search(board, "ABCB")) # False
Palindrome Partitioning
Partition a string into all possible palindrome substrings.
def partition(s):
def is_palindrome(s, left, right):
while left < right:
if s[left] != s[right]:
return False
left += 1
right -= 1
return True
result = []
def backtrack(start, path):
if start == len(s):
result.append(path[:])
return
for end in range(start, len(s)):
if is_palindrome(s, start, end):
path.append(s[start:end+1])
backtrack(end + 1, path)
path.pop() # Backtrack
backtrack(0, [])
return result
# Example usage
s = "aab"
print(partition(s))
# Output: [["a","a","b"], ["aa","b"]]
Backtracking Template
General template for backtracking problems:
def backtrack_template(input_data):
    result = []

    def backtrack(state):
        # Base case: valid solution found
        if is_valid_solution(state):
            result.append(construct_solution(state))
            return
        # Try all possible choices
        for choice in get_choices(state):
            # Make choice
            make_choice(state, choice)
            # Recurse with updated state
            backtrack(state)
            # Undo choice (backtrack)
            undo_choice(state, choice)

    # Initialize and start backtracking
    initial_state = initialize(input_data)
    backtrack(initial_state)
    return result
Time Complexity
Most backtracking algorithms have exponential time complexity:
- Subsets: $O(2^n)$ - each element can be included or excluded
- Permutations: $O(n!)$ - n choices for first, n-1 for second, etc.
- N-Queens: $O(n!)$ - approximately, with pruning
- Sudoku: $O(9^m)$ where m is number of empty cells
Applications
Backtracking is widely used in various applications, including:
- Puzzle Solving: Problems like Sudoku, N-Queens, and mazes can be efficiently solved using backtracking techniques.
- Combinatorial Problems: Generating permutations, combinations, and subsets of a set can be accomplished through backtracking.
- Graph Problems: Backtracking can be applied to find Hamiltonian paths, Eulerian paths, and other graph-related problems.
- Constraint Satisfaction: Solving problems with constraints like graph coloring, map coloring, and scheduling.
Tips for Backtracking
- Identify the decision space: What choices can be made at each step?
- Define constraints: What makes a solution valid or invalid?
- Implement pruning: Abandon paths early when constraints are violated
- Use proper state management: Ensure state is correctly restored when backtracking
- Optimize with memoization: Cache results of repeated subproblems when possible
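Tip 5 in action: the word-break problem (can a string be segmented into dictionary words?) backtracks over split points, and caching the result for each start index means a failed prefix is never retried. A sketch using `functools.lru_cache` for the cache:

```python
from functools import lru_cache

def word_break(s, word_dict):
    words = set(word_dict)

    @lru_cache(maxsize=None)
    def backtrack(start):
        if start == len(s):
            return True  # consumed the whole string
        for end in range(start + 1, len(s) + 1):
            if s[start:end] in words and backtrack(end):
                return True
        return False  # cached: this start index is never re-explored

    return backtrack(0)

# Example usage
print(word_break("leetcode", ["leet", "code"]))                        # True
print(word_break("catsandog", ["cats", "dog", "sand", "and", "cat"]))  # False
```

Without the cache the search is exponential in the worst case; with it, each start index is solved once.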
Conclusion
Backtracking is a powerful algorithmic technique that provides a systematic way to explore all possible solutions to a problem. By leveraging recursion and pruning, it can efficiently solve complex problems that would otherwise require exhaustive search methods.
Divide and Conquer
Divide and conquer is a fundamental algorithmic technique that involves breaking a problem down into smaller subproblems, solving each subproblem independently, and then combining their solutions to solve the original problem. This approach is particularly effective for problems that can be recursively divided into similar subproblems.
Key Concepts
- Divide: The problem is divided into smaller subproblems that are similar to the original problem but smaller in size. This step often involves identifying a base case for the recursion.
- Conquer: Each subproblem is solved independently, often using the same divide and conquer strategy recursively. If the subproblems are small enough, they may be solved directly.
- Combine: The solutions to the subproblems are combined to form a solution to the original problem. This step is crucial as it integrates the results of the smaller problems into a coherent solution.
Merge Sort
Efficient sorting algorithm using divide and conquer.
def merge_sort(arr):
# Base case: array with 0 or 1 element
if len(arr) <= 1:
return arr
# Divide: split array in half
mid = len(arr) // 2
left = merge_sort(arr[:mid])
right = merge_sort(arr[mid:])
# Conquer and Combine: merge sorted halves
return merge(left, right)
def merge(left, right):
result = []
i = j = 0
# Merge while both arrays have elements
while i < len(left) and j < len(right):
if left[i] <= right[j]:
result.append(left[i])
i += 1
else:
result.append(right[j])
j += 1
# Add remaining elements
result.extend(left[i:])
result.extend(right[j:])
return result
# Example usage
arr = [38, 27, 43, 3, 9, 82, 10]
sorted_arr = merge_sort(arr)
print(sorted_arr) # Output: [3, 9, 10, 27, 38, 43, 82]
Time Complexity: $O(n \log n)$
Space Complexity: $O(n)$
Quick Sort
Efficient comparison-based sorting algorithm; the first version below uses auxiliary lists for clarity, and the in-place version follows.
def quick_sort(arr):
if len(arr) <= 1:
return arr
# Choose pivot (middle element)
pivot = arr[len(arr) // 2]
# Divide: partition around pivot
left = [x for x in arr if x < pivot]
middle = [x for x in arr if x == pivot]
right = [x for x in arr if x > pivot]
# Conquer and Combine
return quick_sort(left) + middle + quick_sort(right)
# In-place version
def quick_sort_inplace(arr, low, high):
if low < high:
# Partition and get pivot index
pi = partition(arr, low, high)
# Recursively sort elements before and after partition
quick_sort_inplace(arr, low, pi - 1)
quick_sort_inplace(arr, pi + 1, high)
def partition(arr, low, high):
pivot = arr[high]
i = low - 1
for j in range(low, high):
if arr[j] <= pivot:
i += 1
arr[i], arr[j] = arr[j], arr[i]
arr[i + 1], arr[high] = arr[high], arr[i + 1]
return i + 1
# Example usage
arr = [10, 7, 8, 9, 1, 5]
quick_sort_inplace(arr, 0, len(arr) - 1)
print(arr) # Output: [1, 5, 7, 8, 9, 10]
Time Complexity: $O(n \log n)$ average, $O(n^2)$ worst
Space Complexity: $O(\log n)$ for recursion stack
Binary Search
Classic divide and conquer search algorithm.
def binary_search(arr, target):
left, right = 0, len(arr) - 1
while left <= right:
mid = left + (right - left) // 2
if arr[mid] == target:
return mid
elif arr[mid] < target:
left = mid + 1 # Search right half
else:
right = mid - 1 # Search left half
return -1 # Not found
# Recursive version
def binary_search_recursive(arr, target, left, right):
if left > right:
return -1
mid = left + (right - left) // 2
if arr[mid] == target:
return mid
elif arr[mid] < target:
return binary_search_recursive(arr, target, mid + 1, right)
else:
return binary_search_recursive(arr, target, left, mid - 1)
# Example usage
arr = [1, 3, 5, 7, 9, 11, 13, 15, 17]
print(binary_search(arr, 7)) # Output: 3
print(binary_search_recursive(arr, 13, 0, len(arr) - 1)) # Output: 6
Time Complexity: $O(\log n)$ Space Complexity: $O(1)$ iterative, $O(\log n)$ recursive
Maximum Subarray (Divide and Conquer)
Find the contiguous subarray with the largest sum.
def max_subarray_divide_conquer(arr, left, right):
# Base case: single element
if left == right:
return arr[left]
# Divide: find middle
mid = (left + right) // 2
# Conquer: recursively find max in left and right halves
left_max = max_subarray_divide_conquer(arr, left, mid)
right_max = max_subarray_divide_conquer(arr, mid + 1, right)
# Combine: find max crossing the middle
cross_max = max_crossing_sum(arr, left, mid, right)
return max(left_max, right_max, cross_max)
def max_crossing_sum(arr, left, mid, right):
# Sum from mid to left
left_sum = float('-inf')
current_sum = 0
for i in range(mid, left - 1, -1):
current_sum += arr[i]
left_sum = max(left_sum, current_sum)
# Sum from mid+1 to right
right_sum = float('-inf')
current_sum = 0
for i in range(mid + 1, right + 1):
current_sum += arr[i]
right_sum = max(right_sum, current_sum)
return left_sum + right_sum
# Example usage
arr = [-2, 1, -3, 4, -1, 2, 1, -5, 4]
max_sum = max_subarray_divide_conquer(arr, 0, len(arr) - 1)
print(f"Maximum subarray sum: {max_sum}") # Output: 6 ([4,-1,2,1])
Time Complexity: $O(n \log n)$
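The divide-and-conquer version above runs in $O(n \log n)$; for comparison, Kadane's algorithm solves the same problem in a single $O(n)$ pass:

```python
def max_subarray_kadane(arr):
    # Track the best sum of a subarray ending at the current index,
    # and the best sum seen anywhere so far
    best_ending_here = best_so_far = arr[0]
    for x in arr[1:]:
        # Either extend the previous subarray or start fresh at x
        best_ending_here = max(x, best_ending_here + x)
        best_so_far = max(best_so_far, best_ending_here)
    return best_so_far

arr = [-2, 1, -3, 4, -1, 2, 1, -5, 4]
print(max_subarray_kadane(arr))  # Output: 6
```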
Count Inversions
Count how many pairs are out of order in an array.
def merge_count_inversions(arr):
if len(arr) <= 1:
return arr, 0
mid = len(arr) // 2
left, left_inv = merge_count_inversions(arr[:mid])
right, right_inv = merge_count_inversions(arr[mid:])
merged, split_inv = merge_and_count(left, right)
return merged, left_inv + right_inv + split_inv
def merge_and_count(left, right):
result = []
inversions = 0
i = j = 0
while i < len(left) and j < len(right):
if left[i] <= right[j]:
result.append(left[i])
i += 1
else:
result.append(right[j])
inversions += len(left) - i # All remaining in left are inversions
j += 1
result.extend(left[i:])
result.extend(right[j:])
return result, inversions
# Example usage
arr = [2, 4, 1, 3, 5]
sorted_arr, inversions = merge_count_inversions(arr)
print(f"Inversions: {inversions}") # Output: 3
Closest Pair of Points
Find the two closest points in a 2D plane.
import math
def closest_pair(points):
# Sort points by x-coordinate
px = sorted(points, key=lambda p: p[0])
# Sort points by y-coordinate
py = sorted(points, key=lambda p: p[1])
return closest_pair_recursive(px, py)
def closest_pair_recursive(px, py):
n = len(px)
# Base case: few points, use brute force
if n <= 3:
return brute_force_closest(px)
# Divide: split by vertical line
mid = n // 2
midpoint = px[mid]
pyl = [p for p in py if p[0] <= midpoint[0]]
pyr = [p for p in py if p[0] > midpoint[0]]
# Conquer: find closest in each half
dl = closest_pair_recursive(px[:mid], pyl)
dr = closest_pair_recursive(px[mid:], pyr)
# Find minimum
d = min(dl, dr)
# Combine: check points near dividing line
strip = [p for p in py if abs(p[0] - midpoint[0]) < d]
strip_min = strip_closest(strip, d)
return min(d, strip_min)
def distance(p1, p2):
return math.sqrt((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)
def brute_force_closest(points):
min_dist = float('inf')
for i in range(len(points)):
for j in range(i + 1, len(points)):
min_dist = min(min_dist, distance(points[i], points[j]))
return min_dist
def strip_closest(strip, d):
min_dist = d
for i in range(len(strip)):
for j in range(i + 1, min(i + 7, len(strip))):
min_dist = min(min_dist, distance(strip[i], strip[j]))
return min_dist
# Example usage
points = [(2, 3), (12, 30), (40, 50), (5, 1), (12, 10), (3, 4)]
min_distance = closest_pair(points)
print(f"Smallest distance: {min_distance:.2f}")
Matrix Multiplication (Strassen's Algorithm)
Faster matrix multiplication algorithm.
def strassen_matrix_multiply(A, B):
n = len(A) # Assumes square matrices whose size is a power of two
# Base case: 1x1 matrix
if n == 1:
return [[A[0][0] * B[0][0]]]
# Divide matrices into quadrants
mid = n // 2
A11 = [row[:mid] for row in A[:mid]]
A12 = [row[mid:] for row in A[:mid]]
A21 = [row[:mid] for row in A[mid:]]
A22 = [row[mid:] for row in A[mid:]]
B11 = [row[:mid] for row in B[:mid]]
B12 = [row[mid:] for row in B[:mid]]
B21 = [row[:mid] for row in B[mid:]]
B22 = [row[mid:] for row in B[mid:]]
# Compute 7 products (Strassen's method)
M1 = strassen_matrix_multiply(matrix_add(A11, A22), matrix_add(B11, B22))
M2 = strassen_matrix_multiply(matrix_add(A21, A22), B11)
M3 = strassen_matrix_multiply(A11, matrix_sub(B12, B22))
M4 = strassen_matrix_multiply(A22, matrix_sub(B21, B11))
M5 = strassen_matrix_multiply(matrix_add(A11, A12), B22)
M6 = strassen_matrix_multiply(matrix_sub(A21, A11), matrix_add(B11, B12))
M7 = strassen_matrix_multiply(matrix_sub(A12, A22), matrix_add(B21, B22))
# Combine
C11 = matrix_add(matrix_sub(matrix_add(M1, M4), M5), M7)
C12 = matrix_add(M3, M5)
C21 = matrix_add(M2, M4)
C22 = matrix_add(matrix_sub(matrix_add(M1, M3), M2), M6)
# Construct result
result = []
for i in range(mid):
result.append(C11[i] + C12[i])
for i in range(mid):
result.append(C21[i] + C22[i])
return result
def matrix_add(A, B):
return [[A[i][j] + B[i][j] for j in range(len(A[0]))] for i in range(len(A))]
def matrix_sub(A, B):
return [[A[i][j] - B[i][j] for j in range(len(A[0]))] for i in range(len(A))]
Time Complexity: $O(n^{2.807})$ vs $O(n^3)$ for standard multiplication
Divide and Conquer Template
def divide_and_conquer(problem):
# Base case
if is_simple(problem):
return solve_directly(problem)
# Divide
subproblems = divide(problem)
# Conquer
subsolutions = [divide_and_conquer(subproblem) for subproblem in subproblems]
# Combine
solution = combine(subsolutions)
return solution
Applications
Divide and conquer is widely used in various algorithms and applications, including:
- Sorting Algorithms: Algorithms like Merge Sort and Quick Sort utilize the divide and conquer approach to sort elements efficiently.
- Searching Algorithms: Binary Search is a classic example of a divide and conquer algorithm that efficiently finds an element in a sorted array.
- Matrix Multiplication: Strassen's algorithm for matrix multiplication is another example where the divide and conquer technique is applied to reduce the complexity of the operation.
- Computational Geometry: Problems like finding the closest pair of points or the convex hull.
- Fast Fourier Transform: FFT uses divide and conquer for efficient signal processing.
Advantages
- Efficiency: Often achieves better time complexity than brute force
- Parallelization: Subproblems can be solved independently
- Cache-friendly: Works well with memory hierarchy
- Elegant solutions: Natural recursive structure
Disadvantages
- Overhead: Recursive calls add overhead
- Space complexity: Requires stack space for recursion
- Not always optimal: Some problems have better iterative solutions
Conclusion
The divide and conquer strategy is a powerful tool in algorithm design, enabling efficient solutions to complex problems by breaking them down into manageable parts. Understanding this technique is essential for developing efficient algorithms in computer science and software engineering.
Greedy Algorithms
Greedy algorithms are a class of algorithms that make locally optimal choices at each stage with the hope of finding a global optimum. They are often used for optimization problems where a solution can be built incrementally.
Key Concepts
- Greedy Choice Property: A global optimum can be reached by selecting a local optimum. This property is essential for the effectiveness of greedy algorithms.
- Optimal Substructure: A problem exhibits optimal substructure if an optimal solution to the problem contains optimal solutions to its subproblems.
Activity Selection Problem
Select the maximum number of activities that don't overlap in time.
def activity_selection(activities):
# Sort by finish time
activities.sort(key=lambda x: x[1])
selected = [activities[0]]
last_finish = activities[0][1]
for start, finish in activities[1:]:
if start >= last_finish:
selected.append((start, finish))
last_finish = finish
return selected
# Example usage
activities = [(1, 4), (3, 5), (0, 6), (5, 7), (3, 9), (5, 9), (6, 10), (8, 11), (8, 12), (2, 14), (12, 16)]
result = activity_selection(activities)
print(f"Selected {len(result)} activities:")
for activity in result:
print(f" Start: {activity[0]}, Finish: {activity[1]}")
Time Complexity: $O(n \log n)$ for sorting Space Complexity: $O(n)$
Fractional Knapsack
Maximize value in knapsack by taking fractions of items.
def fractional_knapsack(items, capacity):
# Calculate value per weight and sort by it
items_with_ratio = [(value, weight, value/weight) for value, weight in items]
items_with_ratio.sort(key=lambda x: x[2], reverse=True)
total_value = 0
remaining_capacity = capacity
taken = []
for value, weight, ratio in items_with_ratio:
if remaining_capacity >= weight:
# Take full item
total_value += value
remaining_capacity -= weight
taken.append((value, weight, 1.0))
else:
# Take fraction of item
fraction = remaining_capacity / weight
total_value += value * fraction
taken.append((value, weight, fraction))
break
return total_value, taken
# Example usage
items = [(60, 10), (100, 20), (120, 30)] # (value, weight)
capacity = 50
max_value, taken = fractional_knapsack(items, capacity)
print(f"Maximum value: {max_value}")
print("Items taken:")
for value, weight, fraction in taken:
print(f" Value={value}, Weight={weight}, Fraction={fraction:.2f}")
Time Complexity: $O(n \log n)$
Coin Change (Greedy - doesn't always work!)
Make change using minimum number of coins (works for standard coin systems).
def coin_change_greedy(coins, amount):
coins.sort(reverse=True)
count = 0
result = []
for coin in coins:
while amount >= coin:
amount -= coin
count += 1
result.append(coin)
if amount > 0:
return -1, [] # Cannot make exact change
return count, result
# Example usage (US coins)
coins = [25, 10, 5, 1]
amount = 63
count, result = coin_change_greedy(coins, amount)
print(f"Minimum coins: {count}")
print(f"Coins used: {result}") # [25, 25, 10, 1, 1, 1]
Note: Greedy doesn't always give optimal solution for arbitrary coin systems. For example, with coins [1, 3, 4] and amount 6, greedy gives [4, 1, 1] (3 coins) but optimal is [3, 3] (2 coins).
Huffman Coding
Optimal prefix-free encoding for data compression.
import heapq
from collections import defaultdict
class HuffmanNode:
def __init__(self, char, freq):
self.char = char
self.freq = freq
self.left = None
self.right = None
def __lt__(self, other):
return self.freq < other.freq
def huffman_encoding(text):
# Count frequency
freq = defaultdict(int)
for char in text:
freq[char] += 1
# Create priority queue
heap = [HuffmanNode(char, f) for char, f in freq.items()]
heapq.heapify(heap)
# Build Huffman tree
while len(heap) > 1:
left = heapq.heappop(heap)
right = heapq.heappop(heap)
merged = HuffmanNode(None, left.freq + right.freq)
merged.left = left
merged.right = right
heapq.heappush(heap, merged)
# Generate codes
root = heap[0]
codes = {}
def generate_codes(node, code):
if node.char is not None:
codes[node.char] = code
return
if node.left:
generate_codes(node.left, code + '0')
if node.right:
generate_codes(node.right, code + '1')
generate_codes(root, '')
# Encode text
encoded = ''.join(codes[char] for char in text)
return encoded, codes, root
# Example usage
text = "huffman coding example"
encoded, codes, tree = huffman_encoding(text)
print("Character codes:")
for char, code in sorted(codes.items()):
print(f" '{char}': {code}")
print(f"\nOriginal size: {len(text) * 8} bits")
print(f"Encoded size: {len(encoded)} bits")
print(f"Compression ratio: {len(encoded) / (len(text) * 8):.2%}")
Time Complexity: $O(n \log n)$
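Decoding is not shown above; a minimal sketch (the function name `huffman_decoding` is mine) that walks the tree returned by `huffman_encoding`:

```python
def huffman_decoding(encoded, root):
    # Walk the tree bit by bit: '0' goes left, '1' goes right;
    # emit the character at each leaf and restart from the root
    decoded = []
    node = root
    for bit in encoded:
        node = node.left if bit == '0' else node.right
        if node.char is not None:  # reached a leaf
            decoded.append(node.char)
            node = root
    return ''.join(decoded)
```

Given the `encoded` string and `tree` produced above, `huffman_decoding(encoded, tree)` reproduces the original text, because prefix-free codes make every leaf boundary unambiguous.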
Job Sequencing
Maximize profit by scheduling jobs with deadlines.
def job_sequencing(jobs):
# Sort by profit (descending)
jobs.sort(key=lambda x: x[2], reverse=True)
# Find maximum deadline
max_deadline = max(job[1] for job in jobs)
# Create slot array
slots = [-1] * max_deadline
total_profit = 0
scheduled_jobs = []
# For each job, try to schedule it
for job_id, deadline, profit in jobs:
# Find a free slot before deadline
for slot in range(min(max_deadline, deadline) - 1, -1, -1):
if slots[slot] == -1:
slots[slot] = job_id
total_profit += profit
scheduled_jobs.append((job_id, profit))
break
return total_profit, scheduled_jobs
# Example usage
# Jobs: (job_id, deadline, profit)
jobs = [
('a', 2, 100),
('b', 1, 19),
('c', 2, 27),
('d', 1, 25),
('e', 3, 15)
]
profit, scheduled = job_sequencing(jobs)
print(f"Maximum profit: {profit}")
print("Scheduled jobs:")
for job_id, profit in scheduled:
print(f" Job {job_id}: ${profit}")
Time Complexity: $O(n^2)$
Minimum Spanning Tree - Prim's Algorithm
Find minimum spanning tree of a weighted graph.
import heapq
def prim_mst(graph, start=0):
n = len(graph)
visited = set([start])
edges = [(cost, start, to) for to, cost in graph[start]]
heapq.heapify(edges)
mst = []
total_cost = 0
while edges and len(visited) < n:
cost, frm, to = heapq.heappop(edges)
if to not in visited:
visited.add(to)
mst.append((frm, to, cost))
total_cost += cost
for next_to, next_cost in graph[to]:
if next_to not in visited:
heapq.heappush(edges, (next_cost, to, next_to))
return mst, total_cost
# Example usage
# Graph as adjacency list: graph[node] = [(neighbor, weight), ...]
graph = [
[(1, 2), (3, 6)], # Node 0
[(0, 2), (2, 3), (3, 8), (4, 5)], # Node 1
[(1, 3), (4, 7)], # Node 2
[(0, 6), (1, 8)], # Node 3
[(1, 5), (2, 7)] # Node 4
]
mst, cost = prim_mst(graph)
print(f"Minimum spanning tree cost: {cost}")
print("Edges in MST:")
for frm, to, weight in mst:
print(f" {frm} -- {to} (weight: {weight})")
Time Complexity: $O(E \log V)$ with binary heap
Minimum Spanning Tree - Kruskal's Algorithm
Another MST algorithm using Union-Find.
class UnionFind:
def __init__(self, n):
self.parent = list(range(n))
self.rank = [0] * n
def find(self, x):
if self.parent[x] != x:
self.parent[x] = self.find(self.parent[x])
return self.parent[x]
def union(self, x, y):
px, py = self.find(x), self.find(y)
if px == py:
return False
if self.rank[px] < self.rank[py]:
px, py = py, px
self.parent[py] = px
if self.rank[px] == self.rank[py]:
self.rank[px] += 1
return True
def kruskal_mst(n, edges):
# Sort edges by weight
edges.sort(key=lambda x: x[2])
uf = UnionFind(n)
mst = []
total_cost = 0
for u, v, weight in edges:
if uf.union(u, v):
mst.append((u, v, weight))
total_cost += weight
if len(mst) == n - 1:
break
return mst, total_cost
# Example usage
n = 5 # Number of vertices
edges = [
(0, 1, 2), (0, 3, 6), (1, 2, 3),
(1, 3, 8), (1, 4, 5), (2, 4, 7)
]
mst, cost = kruskal_mst(n, edges)
print(f"Minimum spanning tree cost: {cost}")
print("Edges in MST:")
for u, v, weight in mst:
print(f" {u} -- {v} (weight: {weight})")
Time Complexity: $O(E \log E)$ or $O(E \log V)$
Dijkstra's Shortest Path
Find shortest path from source to all other vertices.
import heapq
def dijkstra(graph, start):
n = len(graph)
dist = [float('inf')] * n
dist[start] = 0
pq = [(0, start)]
visited = set()
while pq:
d, u = heapq.heappop(pq)
if u in visited:
continue
visited.add(u)
for v, weight in graph[u]:
if dist[u] + weight < dist[v]:
dist[v] = dist[u] + weight
heapq.heappush(pq, (dist[v], v))
return dist
# Example usage
graph = [
[(1, 4), (2, 1)], # Node 0
[(3, 1)], # Node 1
[(1, 2), (3, 5)], # Node 2
[(4, 3)], # Node 3
[] # Node 4
]
distances = dijkstra(graph, 0)
print("Shortest distances from node 0:")
for i, d in enumerate(distances):
print(f" To node {i}: {d}")
Time Complexity: $O((V + E) \log V)$ with binary heap
Gas Station Problem
Find starting station to complete circular route.
def can_complete_circuit(gas, cost):
n = len(gas)
total_gas = sum(gas)
total_cost = sum(cost)
# If total gas < total cost, impossible
if total_gas < total_cost:
return -1
start = 0
tank = 0
for i in range(n):
tank += gas[i] - cost[i]
if tank < 0:
# Can't reach next station from current start
start = i + 1
tank = 0
return start
# Example usage
gas = [1, 2, 3, 4, 5]
cost = [3, 4, 5, 1, 2]
start = can_complete_circuit(gas, cost)
print(f"Start at station: {start}") # Output: 3
Time Complexity: $O(n)$
Greedy vs Dynamic Programming
Some problems can be solved by both approaches:
# Greedy (doesn't always work)
def coin_change_greedy(coins, amount):
coins.sort(reverse=True)
count = 0
for coin in coins:
count += amount // coin
amount %= coin
return count if amount == 0 else -1
# Dynamic Programming (always correct)
def coin_change_dp(coins, amount):
dp = [float('inf')] * (amount + 1)
dp[0] = 0
for coin in coins:
for i in range(coin, amount + 1):
dp[i] = min(dp[i], dp[i - coin] + 1)
return dp[amount] if dp[amount] != float('inf') else -1
When to Use Greedy
Use greedy when:
- Problem has greedy choice property
- Problem has optimal substructure
- Local optimum leads to global optimum
Common Greedy Patterns
- Sorting first: Many greedy algorithms start by sorting
- Priority queue: Use heap for best choice at each step
- Intervals: Scheduling problems often use greedy
- Graph traversal: MST, shortest path
Applications
Greedy algorithms are widely used in various applications, including:
- Network Routing: Finding the shortest path in a network (Dijkstra's algorithm)
- Resource Allocation: Distributing resources in a way that maximizes efficiency
- Job Scheduling: Scheduling jobs on machines to minimize completion time
- Data Compression: Huffman coding for optimal compression
- Minimum Spanning Trees: Network design problems
Conclusion
Greedy algorithms provide a straightforward and efficient approach to solving optimization problems. While they do not always yield the optimal solution, they are often easier to implement and can be very effective for certain types of problems.
Sorting Algorithms
Overview
Sorting arranges elements in order. Different algorithms have different trade-offs in speed, memory, and stability.
Common Algorithms
Bubble Sort
Time: $O(n^2)$ | Space: $O(1)$
def bubble_sort(arr):
n = len(arr)
for i in range(n):
for j in range(0, n - i - 1):
if arr[j] > arr[j + 1]:
arr[j], arr[j + 1] = arr[j + 1], arr[j]
return arr
Selection Sort
Time: $O(n^2)$ | Space: $O(1)$
def selection_sort(arr):
for i in range(len(arr)):
min_idx = i
for j in range(i + 1, len(arr)):
if arr[j] < arr[min_idx]:
min_idx = j
arr[i], arr[min_idx] = arr[min_idx], arr[i]
return arr
Insertion Sort
Time: $O(n^2)$ | Space: $O(1)$ | Best: $O(n)$
def insertion_sort(arr):
for i in range(1, len(arr)):
key = arr[i]
j = i - 1
while j >= 0 and arr[j] > key:
arr[j + 1] = arr[j]
j -= 1
arr[j + 1] = key
return arr
Merge Sort
Time: $O(n \log n)$ | Space: $O(n)$ | Stable: ✓
def merge_sort(arr):
if len(arr) <= 1:
return arr
mid = len(arr) // 2
left = merge_sort(arr[:mid])
right = merge_sort(arr[mid:])
return merge(left, right)
def merge(left, right):
result = []
i = j = 0
while i < len(left) and j < len(right):
if left[i] <= right[j]: # <= keeps equal elements in original order (stable)
result.append(left[i])
i += 1
else:
result.append(right[j])
j += 1
result.extend(left[i:])
result.extend(right[j:])
return result
Quick Sort
Time: $O(n \log n)$ avg, $O(n^2)$ worst | Space: $O(\log n)$
def quick_sort(arr):
if len(arr) <= 1:
return arr
pivot = arr[len(arr) // 2]
left = [x for x in arr if x < pivot]
middle = [x for x in arr if x == pivot]
right = [x for x in arr if x > pivot]
return quick_sort(left) + middle + quick_sort(right)
Heap Sort
Time: $O(n \log n)$ | Space: $O(1)$
def heap_sort(arr):
def heapify(arr, n, i):
largest = i
left = 2 * i + 1
right = 2 * i + 2
if left < n and arr[left] > arr[largest]:
largest = left
if right < n and arr[right] > arr[largest]:
largest = right
if largest != i:
arr[i], arr[largest] = arr[largest], arr[i]
heapify(arr, n, largest)
n = len(arr)
for i in range(n // 2 - 1, -1, -1):
heapify(arr, n, i)
for i in range(n - 1, 0, -1):
arr[0], arr[i] = arr[i], arr[0]
heapify(arr, i, 0)
return arr
Comparison
| Algorithm | Best | Average | Worst | Space | Stable |
|---|---|---|---|---|---|
| Bubble | $O(n)$ | $O(n^2)$ | $O(n^2)$ | $O(1)$ | ✓ |
| Selection | $O(n^2)$ | $O(n^2)$ | $O(n^2)$ | $O(1)$ | ✗ |
| Insertion | $O(n)$ | $O(n^2)$ | $O(n^2)$ | $O(1)$ | ✓ |
| Merge | $O(n \log n)$ | $O(n \log n)$ | $O(n \log n)$ | $O(n)$ | ✓ |
| Quick | $O(n \log n)$ | $O(n \log n)$ | $O(n^2)$ | $O(\log n)$ | ✗ |
| Heap | $O(n \log n)$ | $O(n \log n)$ | $O(n \log n)$ | $O(1)$ | ✗ |
When to Use
- Insertion Sort: Small arrays, nearly sorted
- Merge Sort: Need stability, external sorting
- Quick Sort: General purpose, good cache
- Heap Sort: Guaranteed $O(n \log n)$, no extra space
Python Built-in
# Best for most cases
arr.sort() # In-place, O(n log n)
sorted(arr) # Returns new list
# Custom comparator
arr.sort(key=lambda x: x['age'])
ELI10
Different sorting strategies:
- Bubble: Compare neighbors (slow)
- Quick: Pick pivot, divide and conquer (fast)
- Merge: Split in half, merge back (reliable)
Use built-in sorts unless learning!
Searching Algorithms
Overview
Searching algorithms help find elements in data structures. The choice depends on whether data is sorted and the size of the data.
Linear Search
Time: $O(n)$ | Space: $O(1)$ | Works on: Unsorted arrays
def linear_search(arr, target):
for i in range(len(arr)):
if arr[i] == target:
return i
return -1
When to use: Small arrays, unsorted data
Binary Search
Time: $O(\log n)$ | Space: $O(1)$ | Requires: Sorted array
def binary_search(arr, target):
left, right = 0, len(arr) - 1
while left <= right:
mid = left + (right - left) // 2
if arr[mid] == target:
return mid
elif arr[mid] < target:
left = mid + 1
else:
right = mid - 1
return -1 # Not found
Variations
# Find first occurrence
def find_first(arr, target):
left, right = 0, len(arr) - 1
result = -1
while left <= right:
mid = left + (right - left) // 2
if arr[mid] == target:
result = mid
right = mid - 1 # Keep searching left
elif arr[mid] < target:
left = mid + 1
else:
right = mid - 1
return result
# Find last occurrence
def find_last(arr, target):
left, right = 0, len(arr) - 1
result = -1
while left <= right:
mid = left + (right - left) // 2
if arr[mid] == target:
result = mid
left = mid + 1 # Keep searching right
elif arr[mid] < target:
left = mid + 1
else:
right = mid - 1
return result
Two Pointer Technique
Time: $O(n)$ | Space: $O(1)$
def two_sum(arr, target):
"""Find two numbers that sum to target"""
left, right = 0, len(arr) - 1
while left < right:
current_sum = arr[left] + arr[right]
if current_sum == target:
return [left, right]
elif current_sum < target:
left += 1
else:
right -= 1
return []
Jump Search
Time: $O(\sqrt{n})$ | Space: $O(1)$ | Requires: Sorted array
import math
def jump_search(arr, target):
n = len(arr)
step = int(math.sqrt(n))
prev = 0
# Find block where target is present
while arr[min(step, n) - 1] < target:
prev = step
step += int(math.sqrt(n))
if prev >= n:
return -1
# Linear search in block
while arr[prev] < target:
prev += 1
if prev == min(step, n):
return -1
# Check if target found
if arr[prev] == target:
return prev
return -1
Interpolation Search
Time: $O(\log \log n)$ average, $O(n)$ worst | Requires: Sorted uniformly distributed data
def interpolation_search(arr, target):
left, right = 0, len(arr) - 1
while (left <= right and
target >= arr[left] and
target <= arr[right]):
# Guard: all values in range equal (avoids division by zero);
# the loop condition already guarantees target == arr[left] here
if arr[left] == arr[right]:
return left
# Estimate position by linear interpolation
pos = left + int((right - left) / (arr[right] - arr[left]) *
(target - arr[left]))
if arr[pos] == target:
return pos
elif arr[pos] < target:
left = pos + 1
else:
right = pos - 1
return -1
Exponential Search
Time: $O(\log n)$ | Space: $O(1)$ | Requires: Sorted array
def exponential_search(arr, target):
n = len(arr)
# Find range
i = 1
while i < n and arr[i] < target:
i *= 2
# Binary search in range
left = i // 2
right = min(i, n - 1)
while left <= right:
mid = left + (right - left) // 2
if arr[mid] == target:
return mid
elif arr[mid] < target:
left = mid + 1
else:
right = mid - 1
return -1
Sentinel Search
Optimize linear search by eliminating boundary check:
def sentinel_search(arr, target):
n = len(arr)
last = arr[n - 1]
arr[n - 1] = target
i = 0
while arr[i] != target:
i += 1
arr[n - 1] = last # Restore
if i < n - 1 or last == target:
return i
return -1
Comparison
| Algorithm | Time (Avg) | Time (Worst) | Space | Requires Sorted |
|---|---|---|---|---|
| Linear | $O(n)$ | $O(n)$ | $O(1)$ | No |
| Binary | $O(\log n)$ | $O(\log n)$ | $O(1)$ | Yes |
| Jump | $O(\sqrt{n})$ | $O(\sqrt{n})$ | $O(1)$ | Yes |
| Interpolation | $O(\log \log n)$ | $O(n)$ | $O(1)$ | Yes |
| Exponential | $O(\log n)$ | $O(\log n)$ | $O(1)$ | Yes |
Key Takeaways
- Unsorted data? Use Linear Search or Hash Table
- Sorted data? Use Binary Search for $O(\log n)$
- Uniformly distributed? Try Interpolation Search
- Need flexibility? Build a Hash Table for $O(1)$ lookup
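The hash-table option can be as small as a precomputed value-to-index dict, trading $O(n)$ space for $O(1)$ average lookups (a sketch; `build_index` is a made-up helper name):

```python
def build_index(arr):
    # Map each value to its first index; dict lookups average O(1)
    index = {}
    for i, x in enumerate(arr):
        index.setdefault(x, i)
    return index

arr = [1, 3, 5, 7, 9, 11, 13, 15, 17]
idx = build_index(arr)
print(idx.get(7, -1))  # Output: 3
print(idx.get(8, -1))  # Output: -1
```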
ELI10
Imagine finding a word in a dictionary:
- Linear Search: Check every word from start (slow!)
- Binary Search: Open middle, go left or right, repeat (fast!)
- Interpolation: Estimate where word should be based on first letter
Binary search is fastest for sorted data!
Further Resources
- Graphs
- Trees
- Heaps
- Tries
Raft Consensus Algorithm
Raft is a consensus algorithm designed to be easy to understand. It's used for managing a replicated log in distributed systems.
Overview
Raft ensures that a cluster of servers agrees on a sequence of values, even in the presence of failures.
Key Properties:
- Leader election
- Log replication
- Safety
- Membership changes
Server States
┌─────────┐ times out, starts election ┌───────────┐
│Follower │───────────────────────────────>│ Candidate │
└─────────┘ └───────────┘
│ │
│discovers current leader or new term │receives votes from
│ │majority of servers
│ │
│ ▼
│ ┌────────┐
└───────────────────────────────────────│ Leader │
discovers server with └────────┘
higher term
Leader Election
- Follower times out (150-300ms)
- Becomes candidate, increments term
- Votes for itself
- Requests votes from other servers
- If majority votes: becomes leader
- If another leader found: becomes follower
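The randomized timeout in step 1 can be sketched as follows (illustrative only; the names and the millisecond bookkeeping are mine, not from any Raft implementation):

```python
import random

ELECTION_TIMEOUT_MS = (150, 300)  # Raft's suggested range

def new_election_deadline(now_ms):
    # Each server picks its own random timeout in [150, 300) ms;
    # this reduces split votes because servers rarely time out together
    low, high = ELECTION_TIMEOUT_MS
    return now_ms + random.uniform(low, high)

def should_start_election(now_ms, deadline_ms, heard_from_leader):
    # A follower becomes a candidate only if the leader has been
    # silent past its randomized deadline
    return not heard_from_leader and now_ms >= deadline_ms
```

Each AppendEntries heartbeat from the leader resets the deadline, so elections only start when the leader actually goes quiet.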
Log Replication
Leader receives command from client
↓
Append to local log
↓
Send AppendEntries RPCs to followers
↓
Wait for majority to acknowledge
↓
Apply to state machine
↓
Return result to client
Safety Rules
- Election Safety: At most one leader per term
- Leader Append-Only: Leader never overwrites entries
- Log Matching: If two logs contain entry with same index/term, entries are identical
- Leader Completeness: If entry committed in term, it's in leader's log
- State Machine Safety: If server applies entry at index, no other server applies different entry at that index
Example (Python-like pseudocode)
class RaftNode:
def __init__(self):
self.state = "follower"
self.current_term = 0
self.voted_for = None
self.log = []
self.commit_index = 0
def request_vote(self, term, candidate_id):
if term < self.current_term:
return False # Reject candidates from stale terms
if term > self.current_term:
self.current_term = term
self.voted_for = None
if self.voted_for in (None, candidate_id):
self.voted_for = candidate_id
return True
return False
def append_entries(self, term, leader_id, entries):
if term >= self.current_term:
self.state = "follower"
self.current_term = term
self.log.extend(entries)
return True
return False
Raft provides understandable consensus for building reliable distributed systems like etcd, Consul, and CockroachDB.
Security
Comprehensive security reference covering cryptography, authentication, and secure communications.
Cryptography
Encryption
- Symmetric encryption (AES, ChaCha20)
- Asymmetric encryption (RSA, ECC)
- Encryption modes and best practices
- Key management
Hashing
- Cryptographic hash functions
- SHA-256, SHA-3, BLAKE2
- Password hashing (bcrypt, Argon2)
- Hash-based applications
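bcrypt and Argon2 require third-party packages, but the standard library's PBKDF2 (`hashlib.pbkdf2_hmac`) illustrates the same salt-plus-work-factor pattern; a minimal sketch:

```python
import hashlib
import hmac
import os

def hash_password(password: str, *, iterations: int = 200_000):
    # Unique random salt per password; a high iteration count slows brute force
    salt = os.urandom(16)
    digest = hashlib.pbkdf2_hmac('sha256', password.encode(), salt, iterations)
    return salt, iterations, digest

def verify_password(password: str, salt: bytes, iterations: int, expected: bytes) -> bool:
    digest = hashlib.pbkdf2_hmac('sha256', password.encode(), salt, iterations)
    # Constant-time comparison avoids timing side channels
    return hmac.compare_digest(digest, expected)

salt, iters, digest = hash_password("correct horse battery staple")
print(verify_password("correct horse battery staple", salt, iters, digest))  # Output: True
print(verify_password("wrong guess", salt, iters, digest))                   # Output: False
```

Store the salt and iteration count alongside the digest so the work factor can be raised later without breaking existing hashes.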
HMAC
- Hash-based Message Authentication Code
- Message integrity and authenticity
- HMAC construction and usage
- Applications in APIs and tokens
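A minimal example with Python's built-in `hmac` module, e.g. tagging an API payload (the key and message here are made up):

```python
import hashlib
import hmac

def sign(key: bytes, message: bytes) -> str:
    # HMAC-SHA256 tag: proves integrity and authenticity to holders of the key
    return hmac.new(key, message, hashlib.sha256).hexdigest()

def verify(key: bytes, message: bytes, tag: str) -> bool:
    # compare_digest prevents timing attacks on the tag comparison
    return hmac.compare_digest(sign(key, message), tag)

key = b"shared-secret"
tag = sign(key, b"amount=100&to=alice")
print(verify(key, b"amount=100&to=alice", tag))    # Output: True
print(verify(key, b"amount=999&to=mallory", tag))  # Output: False
```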
Authentication & Authorization
OAuth 2.0
- Authorization framework and grant types
- Authorization Code, Client Credentials, PKCE
- Access tokens and refresh tokens
- OpenID Connect for authentication
- Implementation best practices
JWT (JSON Web Tokens)
- Token structure (header, payload, signature)
- Signing algorithms (HS256, RS256, ES256)
- Token validation and verification
- Use cases and security considerations
- Best practices for token management
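To make the header.payload.signature structure concrete, here is a minimal HS256 signer using only the standard library. This is a sketch for illustration; production code should use a vetted JWT library and validate claims (`exp`, `aud`, etc.):

```python
import base64
import hashlib
import hmac
import json

def b64url(data: bytes) -> str:
    # JWTs use unpadded base64url encoding
    return base64.urlsafe_b64encode(data).rstrip(b'=').decode()

def jwt_hs256(payload: dict, secret: bytes) -> str:
    header = {"alg": "HS256", "typ": "JWT"}
    signing_input = (b64url(json.dumps(header, separators=(',', ':')).encode())
                     + "." + b64url(json.dumps(payload, separators=(',', ':')).encode()))
    # The signature covers header and payload; anyone with the secret can verify
    sig = hmac.new(secret, signing_input.encode(), hashlib.sha256).digest()
    return signing_input + "." + b64url(sig)

token = jwt_hs256({"sub": "alice", "admin": False}, b"secret")
print(token.count("."))  # Output: 2  (header.payload.signature)
```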
Digital Signatures
Digital Signatures
- RSA signatures
- ECDSA (Elliptic Curve Digital Signature Algorithm)
- EdDSA (Edwards-curve Digital Signature Algorithm)
- Signature verification
- Applications (code signing, documents)
Certificates
- X.509 certificates
- Certificate Authorities (CAs)
- Certificate chains and trust
- Certificate management
- Let's Encrypt and ACME protocol
Secure Communications
SSL/TLS
- TLS handshake process
- Cipher suites
- Certificate validation
- TLS 1.2 vs TLS 1.3
- Common vulnerabilities (BEAST, POODLE, Heartbleed)
- Best practices and configuration
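Python's `ssl` module defaults already reflect several of these practices; a small sketch of a client-side context configured along the lines recommended above:

```python
import ssl

# create_default_context enables secure defaults: certificate and
# hostname verification on, known-weak protocols and ciphers disabled
ctx = ssl.create_default_context()
ctx.minimum_version = ssl.TLSVersion.TLSv1_2  # refuse TLS 1.0/1.1

print(ctx.verify_mode == ssl.CERT_REQUIRED)  # Output: True
print(ctx.check_hostname)                    # Output: True
```

Wrapping a socket with `ctx.wrap_socket(sock, server_hostname="example.com")` then performs the handshake and certificate validation automatically.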
Quick Reference
Common Algorithms
| Algorithm | Type | Key Size | Use Case |
|---|---|---|---|
| AES | Symmetric | 128/192/256-bit | General encryption |
| ChaCha20 | Symmetric | 256-bit | Mobile/embedded |
| RSA | Asymmetric | 2048/4096-bit | Key exchange, signatures |
| ECDSA | Asymmetric | 256-bit | Signatures (Bitcoin) |
| SHA-256 | Hash | N/A | Checksums, Bitcoin |
| bcrypt | Password Hash | N/A | Password storage |
| Argon2 | Password Hash | N/A | Password storage (modern) |
Security Best Practices
- Use Modern Algorithms
  - AES-256 for symmetric encryption
  - RSA-2048 minimum, prefer ECC
  - SHA-256 or SHA-3 for hashing
  - Argon2 for password hashing
- Key Management
  - Generate strong random keys
  - Rotate keys regularly
  - Use HSM for critical keys
  - Never hardcode secrets
- TLS Configuration
  - Use TLS 1.2 minimum (prefer 1.3)
  - Disable weak cipher suites
  - Enable Perfect Forward Secrecy
  - Use strong certificate chains
- Password Storage
  - Never store plaintext passwords
  - Use bcrypt or Argon2
  - Add unique salt per password
  - Use appropriate work factors
- API Security
  - Use HMAC for message integrity
  - Implement rate limiting
  - Use short-lived tokens
  - Validate all inputs
Common Tools
# OpenSSL
openssl enc -aes-256-cbc -in file.txt -out file.enc
openssl req -new -x509 -days 365 -key key.pem -out cert.pem
# Generate keys
ssh-keygen -t ed25519
openssl genrsa -out private.key 2048
# Hashing
sha256sum file.txt
openssl dgst -sha256 file.txt
# Certificate inspection
openssl x509 -in cert.pem -text -noout
openssl s_client -connect example.com:443
Related Topics
- Network security (firewalls, VPNs)
- Application security (OWASP Top 10)
- Authentication protocols (OAuth, SAML)
- Blockchain and cryptocurrencies
Cryptographic Hash Functions
Overview
A cryptographic hash function is a mathematical algorithm that takes an input (message) of any size and produces a fixed-size output (hash digest). Hash functions are one-way functions designed to be computationally infeasible to reverse.
Key Properties
1. Deterministic
Same input always produces the same output:
hash("hello") = 2cf24dba5fb0a30e...
hash("hello") = 2cf24dba5fb0a30e... (always the same)
2. Fast Computation
Quick to compute hash for any input
3. Pre-image Resistance (One-way)
Given hash h, computationally infeasible to find message m where hash(m) = h
4. Collision Resistance
Computationally infeasible to find two different messages m1 and m2 where:
hash(m1) = hash(m2)
5. Avalanche Effect
Small change in input drastically changes output:
hash("hello") = 2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824
hash("Hello") = 185f8db32271fe25f561a6fc938b2e264306ec304eda518007d1764826381969
Common Hash Functions
| Algorithm | Output Size | Status | Use Cases |
|---|---|---|---|
| MD5 | 128 bits (16 bytes) | Broken | Checksums only |
| SHA-1 | 160 bits (20 bytes) | Deprecated | Legacy systems |
| SHA-256 | 256 bits (32 bytes) | Secure | General purpose |
| SHA-512 | 512 bits (64 bytes) | Secure | High security |
| SHA-3 | Variable | Secure | Modern alternative |
| BLAKE2 | Variable | Secure | Fast, modern |
| BLAKE3 | 256 bits | Secure | Fastest, modern |
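The output sizes in the table can be confirmed with hashlib (BLAKE3 is not in the standard library, so it is omitted here):

```python
import hashlib

data = b"Hello, World!"
sizes = {}
for name in ["md5", "sha1", "sha256", "sha512", "sha3_256", "blake2b"]:
    digest = hashlib.new(name, data).digest()
    sizes[name] = len(digest) * 8  # digest length in bits
    print(f"{name:10} {sizes[name]:4} bits")
```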
SHA-256 (Secure Hash Algorithm 256)
Algorithm Overview
SHA-256 is part of the SHA-2 family, designed by the NSA and published in 2001.
Process:
- Pad message to multiple of 512 bits
- Initialize hash values (8 x 32-bit words)
- Process message in 512-bit chunks
- Each chunk goes through 64 rounds of operations
- Produce final 256-bit hash
Using SHA-256
Python Example
import hashlib
# Hash a string
message = "Hello, World!"
hash_object = hashlib.sha256(message.encode())
hash_hex = hash_object.hexdigest()
print(f"SHA-256: {hash_hex}")
# Output: SHA-256: dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f
# Hash a file
def hash_file(filename):
sha256_hash = hashlib.sha256()
with open(filename, "rb") as f:
# Read file in chunks to handle large files
for byte_block in iter(lambda: f.read(4096), b""):
sha256_hash.update(byte_block)
return sha256_hash.hexdigest()
file_hash = hash_file("document.pdf")
print(f"File hash: {file_hash}")
# Incremental hashing
hasher = hashlib.sha256()
hasher.update(b"Hello, ")
hasher.update(b"World!")
print(hasher.hexdigest())
# Same as hashing "Hello, World!" at once
Bash/OpenSSL Example
# Hash a string
echo -n "Hello, World!" | sha256sum
echo -n "Hello, World!" | openssl dgst -sha256
# Hash a file
sha256sum document.pdf
openssl dgst -sha256 document.pdf
# Verify file integrity
sha256sum document.pdf > checksum.txt
sha256sum -c checksum.txt
# Hash multiple files
sha256sum *.pdf > all_checksums.txt
SHA-256 Output Format
Input: "Hello, World!"
Binary (256 bits):
11011111111111010110000000100001...
Hexadecimal (64 characters):
dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f
Base64 (44 characters):
3/1gIbsr1bCvZ2KQgJ7DpTGR3YHH9wpLKGiKNiGCmG8=
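All three forms are just different encodings of the same 32 raw bytes; a quick sketch:

```python
import base64
import hashlib

digest = hashlib.sha256(b"Hello, World!").digest()  # 32 raw bytes
print(len(digest))                        # 32
print(digest.hex())                       # 64 hex characters
print(base64.b64encode(digest).decode())  # 44 Base64 characters
```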
SHA-3 (Keccak)
SHA-3 is based on a different construction (sponge function) than SHA-2, providing an alternative if SHA-2 is compromised.
Using SHA-3
import hashlib
message = "Hello, World!"
# SHA-3 variants
sha3_256 = hashlib.sha3_256(message.encode()).hexdigest()
sha3_512 = hashlib.sha3_512(message.encode()).hexdigest()
print(f"SHA3-256: {sha3_256}")
print(f"SHA3-512: {sha3_512}")
# SHAKE (extendable output)
shake = hashlib.shake_256(message.encode())
# Get 32 bytes of output
print(f"SHAKE256: {shake.hexdigest(32)}")
BLAKE2
Faster than SHA-2 and SHA-3, with built-in keyed hashing and salting support.
Using BLAKE2
import hashlib
message = b"Hello, World!"
# BLAKE2b (optimized for 64-bit platforms)
blake2b = hashlib.blake2b(message).hexdigest()
print(f"BLAKE2b: {blake2b}")
# BLAKE2s (optimized for 8-32 bit platforms)
blake2s = hashlib.blake2s(message).hexdigest()
print(f"BLAKE2s: {blake2s}")
# Keyed hashing (MAC)
key = b"secret-key-123"
mac = hashlib.blake2b(message, key=key).hexdigest()
print(f"BLAKE2b MAC: {mac}")
# Custom digest size
digest = hashlib.blake2b(message, digest_size=16).hexdigest()
print(f"BLAKE2b-128: {digest}")
# With salt (for password hashing)
salt = b"random-salt-16bytes!"
h = hashlib.blake2b(message, salt=salt, digest_size=32)
print(f"BLAKE2b with salt: {h.hexdigest()}")
Password Hashing
WARNING: Never use fast hashes (SHA-256, MD5) for passwords! Use specialized password hashing functions.
Why Not SHA-256 for Passwords?
# BAD - vulnerable to brute force
import hashlib
password = "password123"
hash = hashlib.sha256(password.encode()).hexdigest()
# Attacker can compute billions of SHA-256 hashes per second!
Password Hashing Requirements
- Slow: Intentionally slow to prevent brute force
- Salted: Random salt prevents rainbow tables
- Adaptive: Can increase work factor over time
- Memory-hard: Requires significant memory (for some algorithms)
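hashlib.scrypt from the standard library is one memory-hard option that meets these requirements; a minimal sketch (the cost parameters are illustrative, tune them for your hardware):

```python
import hashlib
import hmac
import os

def hash_password(password: str) -> bytes:
    # scrypt is deliberately slow and memory-hard; n is the CPU/memory cost
    salt = os.urandom(16)
    key = hashlib.scrypt(password.encode(), salt=salt,
                         n=2**14, r=8, p=1, dklen=32)
    return salt + key  # store the salt alongside the derived key

def verify_password(password: str, stored: bytes) -> bool:
    salt, key = stored[:16], stored[16:]
    candidate = hashlib.scrypt(password.encode(), salt=salt,
                               n=2**14, r=8, p=1, dklen=32)
    return hmac.compare_digest(candidate, key)  # constant-time comparison
```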
bcrypt
Overview
- Based on Blowfish cipher
- Adaptive (configurable work factor)
- Automatic salt generation
- Maximum password length: 72 bytes
Using bcrypt
import bcrypt
# Hash a password
password = b"my_secure_password"
salt = bcrypt.gensalt(rounds=12) # 2^12 iterations
hashed = bcrypt.hashpw(password, salt)
print(f"Hashed: {hashed}")
# Output: b'$2b$12$KIXx8Z9...'
# Verify password
if bcrypt.checkpw(password, hashed):
print("Password matches!")
else:
print("Invalid password")
# Increase work factor over time
def needs_rehash(hashed_password, min_rounds=12):
# Extract current rounds from hash
parts = hashed_password.decode().split('$')
current_rounds = int(parts[2])
return current_rounds < min_rounds
# Complete example
def hash_password(password):
return bcrypt.hashpw(password.encode(), bcrypt.gensalt(rounds=12))
def verify_password(password, hashed):
return bcrypt.checkpw(password.encode(), hashed)
# Usage
user_password = "SuperSecret123!"
stored_hash = hash_password(user_password)
# Later, during login
login_password = "SuperSecret123!"
if verify_password(login_password, stored_hash):
print("Login successful!")
bcrypt Hash Format
$2b$12$KIXx8Z9ByF7LHfG8z.yNH.Q5GF8Z9ByF7LHfG8z.yNH.Q5GF8Z9ByF7
 |  |  |                     |
 |  |  Salt (22 characters)  Hash (31 characters)
 |  Cost factor (2^12 iterations)
 Algorithm identifier (2b = bcrypt)
Argon2
Overview
Winner of the Password Hashing Competition (2015). Memory-hard algorithm resistant to GPU/ASIC attacks.
Variants:
- Argon2d: Resistant to GPU attacks (not side-channel resistant)
- Argon2i: Resistant to side-channel attacks
- Argon2id: Hybrid (recommended)
Using Argon2
from argon2 import PasswordHasher
from argon2.exceptions import VerifyMismatchError
# Create hasher with default parameters
ph = PasswordHasher()
# Hash password
password = "my_secure_password"
hashed = ph.hash(password)
print(f"Hashed: {hashed}")
# Output: $argon2id$v=19$m=65536,t=3,p=4$...
# Verify password
try:
ph.verify(hashed, password)
print("Password matches!")
except VerifyMismatchError:
print("Invalid password")
# Check if hash needs rehashing (parameters changed)
if ph.check_needs_rehash(hashed):
new_hash = ph.hash(password)
# Update in database
# Custom parameters
from argon2 import PasswordHasher
custom_ph = PasswordHasher(
time_cost=3, # Number of iterations
memory_cost=65536, # Memory usage in KiB (64 MB)
parallelism=4, # Number of parallel threads
hash_len=32, # Hash length in bytes
salt_len=16 # Salt length in bytes
)
hashed = custom_ph.hash(password)
Argon2 Hash Format
$argon2id$v=19$m=65536,t=3,p=4$c29tZXNhbHQ$hash_output_here
 |        |    |               |           |
 |        |    |               |           Hash output (base64)
 |        |    |               Salt (base64)
 |        |    Parameters (memory, time cost, parallelism)
 |        Version
 Variant (id, i, or d)
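The encoded string can be pulled apart with plain string operations; a sketch (the returned field names are my own, not an argon2 API):

```python
def parse_argon2_hash(encoded: str) -> dict:
    # Format: "$argon2id$v=19$m=65536,t=3,p=4$<salt>$<hash>"
    _, variant, version, params, salt, digest = encoded.split("$")
    fields = dict(p.split("=") for p in params.split(","))
    return {
        "variant": variant,
        "version": int(version.split("=")[1]),
        "memory_kib": int(fields["m"]),
        "time_cost": int(fields["t"]),
        "parallelism": int(fields["p"]),
        "salt_b64": salt,
        "hash_b64": digest,
    }

info = parse_argon2_hash("$argon2id$v=19$m=65536,t=3,p=4$c29tZXNhbHQ$abc123")
print(info["memory_kib"], info["time_cost"], info["parallelism"])  # 65536 3 4
```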
Argon2 Parameters Guide
# Low security (fast, for testing)
time_cost=1, memory_cost=8192, parallelism=1
# Medium security (default)
time_cost=3, memory_cost=65536, parallelism=4
# High security
time_cost=5, memory_cost=262144, parallelism=8
# Extreme security
time_cost=10, memory_cost=1048576, parallelism=16
Salting
A salt is random data added to passwords before hashing to prevent rainbow table attacks.
Without Salt (Vulnerable)
# BAD - Same password = Same hash
hash("password123") = "abc123..."
hash("password123") = "abc123..." # Attacker can precompute!
With Salt (Secure)
# GOOD - Same password = Different hashes
hash("password123" + "random_salt_1") = "xyz789..."
hash("password123" + "random_salt_2") = "def456..."
Implementing Salt
import hashlib
import os
def hash_password_with_salt(password):
# Generate random salt (16 bytes = 128 bits)
salt = os.urandom(16)
# Combine password and salt
pwdhash = hashlib.pbkdf2_hmac('sha256',
password.encode(),
salt,
100000) # iterations
# Store both salt and hash
return salt + pwdhash
def verify_password(stored_password, provided_password):
# Extract salt (first 16 bytes)
salt = stored_password[:16]
# Extract hash (remaining bytes)
stored_hash = stored_password[16:]
# Hash provided password with same salt
pwdhash = hashlib.pbkdf2_hmac('sha256',
provided_password.encode(),
salt,
100000)
return pwdhash == stored_hash
# Usage
password = "my_password"
stored = hash_password_with_salt(password)
# Verify
if verify_password(stored, "my_password"):
print("Correct password")
PBKDF2 (Password-Based Key Derivation Function 2)
Standard algorithm for deriving cryptographic keys from passwords.
import hashlib
import os
password = b"my_password"
salt = b"random_salt_123"
# Derive key
key = hashlib.pbkdf2_hmac(
'sha256', # Hash algorithm
password, # Password
salt, # Salt
100000, # Iterations
dklen=32 # Desired key length in bytes
)
print(f"Derived key: {key.hex()}")
# For password storage
def store_password(password):
salt = os.urandom(16)
hash = hashlib.pbkdf2_hmac('sha256', password.encode(), salt, 100000)
# Store: salt + hash
return salt.hex() + '$' + hash.hex()
def check_password(password, stored):
salt_hex, hash_hex = stored.split('$')
salt = bytes.fromhex(salt_hex)
hash = bytes.fromhex(hash_hex)
new_hash = hashlib.pbkdf2_hmac('sha256', password.encode(), salt, 100000)
return new_hash == hash
Use Cases
1. File Integrity Verification
# Create checksum
sha256sum important_file.pdf > checksum.txt
# Later, verify file hasn't changed
sha256sum -c checksum.txt
2. Git Commits
Git uses SHA-1 (moving to SHA-256) to identify commits:
git log --oneline
# a1b2c3d Fix bug in authentication
3. Digital Signatures
Hash the message first, then sign the hash:
message -> hash -> encrypt with private key -> signature
4. Proof of Work (Blockchain)
import hashlib
import time
def mine_block(data, difficulty=4):
nonce = 0
target = "0" * difficulty
while True:
message = f"{data}{nonce}"
hash = hashlib.sha256(message.encode()).hexdigest()
if hash.startswith(target):
return nonce, hash
nonce += 1
# Mine a block (find hash starting with 0000)
data = "Block data here"
nonce, hash = mine_block(data, difficulty=4)
print(f"Nonce: {nonce}, Hash: {hash}")
5. Message Deduplication
import hashlib
def deduplicate_messages(messages):
seen_hashes = set()
unique_messages = []
for msg in messages:
msg_hash = hashlib.sha256(msg.encode()).hexdigest()
if msg_hash not in seen_hashes:
seen_hashes.add(msg_hash)
unique_messages.append(msg)
return unique_messages
6. Content-Addressable Storage
import hashlib
import os
class ContentAddressableStorage:
def __init__(self, storage_dir):
self.storage_dir = storage_dir
os.makedirs(storage_dir, exist_ok=True)
def store(self, data):
# Hash determines storage location
hash = hashlib.sha256(data).hexdigest()
path = os.path.join(self.storage_dir, hash)
with open(path, 'wb') as f:
f.write(data)
return hash
def retrieve(self, hash):
path = os.path.join(self.storage_dir, hash)
with open(path, 'rb') as f:
return f.read()
# Usage
cas = ContentAddressableStorage('/tmp/cas')
content = b"Important document content"
hash = cas.store(content)
retrieved = cas.retrieve(hash)
Hash Comparison
Performance Benchmark (Python)
import hashlib
import time
data = b"x" * 1000000 # 1 MB of data
algorithms = ['md5', 'sha1', 'sha256', 'sha512', 'sha3_256', 'blake2b']
for algo in algorithms:
start = time.time()
for _ in range(100):
hashlib.new(algo, data).digest()
elapsed = time.time() - start
print(f"{algo:12} {elapsed:.3f}s")
# Typical results:
# md5 0.125s (fastest, but insecure)
# sha1 0.156s (fast, but deprecated)
# blake2b 0.187s (fast and secure)
# sha256 0.234s (standard, secure)
# sha512 0.187s (fast on 64-bit, secure)
# sha3_256 0.876s (slower, secure)
Security Considerations
1. Never Use MD5 or SHA-1 for Security
# VULNERABLE - collision attacks exist
md5_hash = hashlib.md5(data).hexdigest()
sha1_hash = hashlib.sha1(data).hexdigest()
# USE INSTEAD
sha256_hash = hashlib.sha256(data).hexdigest()
2. Always Salt Passwords
# BAD
password_hash = hashlib.sha256(password.encode()).hexdigest()
# GOOD
import bcrypt
password_hash = bcrypt.hashpw(password.encode(), bcrypt.gensalt())
3. Use Appropriate Hash for Use Case
File integrity: SHA-256, BLAKE2
Password storage: bcrypt, Argon2, PBKDF2
General purpose: SHA-256, SHA-3, BLAKE2
High performance: BLAKE2, BLAKE3
Cryptographic: SHA-256, SHA-3
4. Timing Attacks
# VULNERABLE - timing attack
if hash1 == hash2:
return True
# SAFE - constant time comparison
import hmac
if hmac.compare_digest(hash1, hash2):
return True
5. Hash Length Extension Attacks
SHA-256 is vulnerable to length extension attacks. Use HMAC instead for authentication:
# VULNERABLE - attacker can extend the message and recompute the tag
auth_tag = hashlib.sha256(secret + message).digest()
# SAFE - the HMAC construction is not vulnerable to length extension
import hmac
auth_tag = hmac.new(secret, message, hashlib.sha256).digest()
Best Practices
1. Password Hashing Checklist
# ✓ Use specialized password hash (bcrypt, Argon2)
# ✓ Use random salt (automatic in bcrypt/Argon2)
# ✓ Use sufficient work factor
# ✓ Use constant-time comparison
# ✓ Plan for rehashing when parameters change
from argon2 import PasswordHasher
from argon2.exceptions import VerifyMismatchError
ph = PasswordHasher()
def hash_password(password):
    return ph.hash(password)
def verify_password(password, hash):
    try:
        ph.verify(hash, password)
        return True
    except VerifyMismatchError:
        return False
2. File Integrity
# Generate checksums for all files
find . -type f -exec sha256sum {} \; > checksums.txt
# Verify later
sha256sum -c checksums.txt
3. Secure Random Salt Generation
import os
# Use cryptographically secure random
salt = os.urandom(16) # 128 bits
# DON'T use regular random module
import random
salt = random.randbytes(16) # NOT SECURE!
4. Database Schema for Passwords
CREATE TABLE users (
id INT PRIMARY KEY,
username VARCHAR(255) UNIQUE NOT NULL,
password_hash VARCHAR(255) NOT NULL, -- Store full hash string
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
password_updated_at TIMESTAMP
);
-- bcrypt example:
-- password_hash: $2b$12$KIXx8Z9ByF7LHfG8z.yNH.Q5GF8Z9ByF7...
-- Argon2 example:
-- password_hash: $argon2id$v=19$m=65536,t=3,p=4$c29tZXNhbHQ$...
Common Mistakes
1. Double Hashing
# BAD - doesn't increase security
hash1 = sha256(password)
hash2 = sha256(hash1) # No benefit!
# GOOD - use proper password hashing
hash = bcrypt.hashpw(password, bcrypt.gensalt())
2. Homemade Crypto
# BAD - creating your own hash function
def my_hash(data):
result = 0
for byte in data:
result = (result * 31 + byte) % 1000000007
return result
# GOOD - use standard algorithms
import hashlib
hash = hashlib.sha256(data).hexdigest()
3. Insufficient Work Factor
# BAD - too fast, vulnerable to brute force
hash = bcrypt.hashpw(password, bcrypt.gensalt(rounds=4)) # 2^4 = 16 iterations
# GOOD - sufficient work factor
hash = bcrypt.hashpw(password, bcrypt.gensalt(rounds=12)) # 2^12 = 4096 iterations
ELI10
A hash function is like a magic blender for data:
- You put something in: "Hello, World!"
- It blends it up: The blender scrambles everything
- You get a unique smoothie: "dffd6021bb2b..."
Special properties:
- Always the same: Same ingredients = Same smoothie
- One-way: Can't un-blend the smoothie to get ingredients back
- Tiny changes matter: "Hello, World!" vs "Hello, World?" = Completely different smoothies
- Same size: Whether you blend a strawberry or a watermelon, you always get the same size cup
For passwords, we use special slow blenders (bcrypt, Argon2):
- Regular blender: Makes 1 million smoothies per second (easy to guess passwords!)
- Password blender: Makes 10 smoothies per second (hard to guess passwords!)
Salt is like adding random spices:
- Without salt: Everyone who uses "password123" gets the same smoothie
- With salt: Everyone gets different random spices, so same password = different smoothies
Further Resources
- SHA-256 Specification (NIST)
- Password Hashing Competition
- Argon2 RFC 9106
- OWASP Password Storage Cheat Sheet
- Hash Length Extension Attacks
- bcrypt Documentation
- Argon2 Documentation
Encryption
Overview
Encryption converts readable data (plaintext) into unreadable data (ciphertext) using mathematical algorithms and keys. Only those with the correct key can decrypt it.
Types of Encryption
Symmetric Encryption
Same key encrypts and decrypts:
Plaintext + Key --[Encrypt]--> Ciphertext
Ciphertext + Key --[Decrypt]--> Plaintext
Algorithms:
- AES (Advanced Encryption Standard): Industry standard, 128/192/256-bit keys
- ChaCha20: Modern, fast, secure
- DES: Obsolete, 56-bit key (too short)
When to use: Database encryption, file encryption, internal communication
Asymmetric Encryption
Different keys for encryption/decryption:
Plaintext + Public Key --[Encrypt]--> Ciphertext
Ciphertext + Private Key --[Decrypt]--> Plaintext
Algorithms:
- RSA: Based on factoring difficulty, 2048/4096-bit keys
- ECC (Elliptic Curve): Shorter keys, same security as RSA
- Diffie-Hellman: Key exchange, not encryption
When to use: HTTPS, email encryption, digital signatures
Key Concepts
Key Size
More bits = More security but slower
AES-128: 2^128 possible keys (weakened to ~2^64 effort by Grover's algorithm)
AES-256: 2^256 possible keys (considered quantum-resistant)
RSA-2048: ≈112-bit symmetric equivalent
ECC-256: ≈128-bit symmetric equivalent
Modes of Operation (Symmetric)
| Mode | Use | Properties |
|---|---|---|
| ECB | Never use | Reveals patterns |
| CBC | File encryption | Needs IV, not parallel |
| CTR | Streaming | Parallelizable |
| GCM | Authenticated encryption | Authentication built-in |
Initialization Vector (IV)
Random value ensuring same plaintext produces different ciphertext:
Plaintext: "Hello"
IV1 + Key --[AES-CBC]--> "xK#$%"
IV2 + Key --[AES-CBC]--> "mN&*@" (different!)
Code Examples
Python - Symmetric (AES)
from cryptography.fernet import Fernet
# Generate key (keep secret!)
key = Fernet.generate_key() # Save this securely
cipher = Fernet(key)
# Encrypt
plaintext = b"Secret message"
ciphertext = cipher.encrypt(plaintext) # b'gAAAAABl...'
# Decrypt
plaintext = cipher.decrypt(ciphertext) # b"Secret message"
Python - Asymmetric (RSA)
from cryptography.hazmat.primitives.asymmetric import rsa, padding
from cryptography.hazmat.primitives import hashes
# Generate key pair
private_key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048,
)
public_key = private_key.public_key()
# Encrypt with public key
ciphertext = public_key.encrypt(
    b"Secret",
    padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA256()),
        algorithm=hashes.SHA256(),
        label=None
    )
)
# Decrypt with private key
plaintext = private_key.decrypt(
    ciphertext,
    padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA256()),
        algorithm=hashes.SHA256(),
        label=None
    )
)
JavaScript - AES Encryption
import crypto from 'crypto';
// Encrypt
const key = crypto.randomBytes(32); // 256-bit key
const iv = crypto.randomBytes(16); // 128-bit IV
const cipher = crypto.createCipheriv('aes-256-cbc', key, iv);
let ciphertext = cipher.update('Secret', 'utf8', 'hex');
ciphertext += cipher.final('hex');
// Decrypt
const decipher = crypto.createDecipheriv('aes-256-cbc', key, iv);
let plaintext = decipher.update(ciphertext, 'hex', 'utf8');
plaintext += decipher.final('utf8');
Common Algorithms Comparison
| Algorithm | Type | Speed | Security | Use Case |
|---|---|---|---|---|
| AES | Symmetric | Fast | Very good | General encryption |
| ChaCha20 | Symmetric | Very fast | Very good | Mobile/streaming |
| RSA | Asymmetric | Slow | Good | Key exchange, signatures |
| ECC | Asymmetric | Medium | Excellent | HTTPS, signatures |
| DES | Symmetric | Fast | Broken | Legacy only |
Attacks on Encryption
Brute Force
Try all possible keys:
Defense: Use strong key (256-bit AES)
Cost: 2^256 operations (infeasible)
Side-Channel Attacks
Extract info from timing, power usage:
Defense: Constant-time operations
Weak Random Number Generator
Predictable keys:
Defense: Use cryptographically secure RNG
Quantum Computing
Shor's algorithm threatens RSA and ECC; symmetric ciphers are only weakened:
Current RSA: 2048-bit (broken by a sufficiently large quantum computer)
Post-quantum: need lattice-based or other post-quantum algorithms
AES-256: still secure (Grover's algorithm only halves effective key strength)
Best Practices
1. Choose Right Algorithm
- Use: AES-256 for symmetric
- Use: RSA-2048 or ECC-256 for asymmetric
- Avoid: DES, MD5, SHA-1 (broken/deprecated)
2. Secure Key Storage
# Bad: Hardcoded key
key = "supersecret123"
# Good: Load from secure storage
key = os.environ.get('ENCRYPTION_KEY')
# Or use key management service (AWS KMS, HashiCorp Vault)
3. Use Authenticated Encryption
# Use modes that verify integrity (GCM, authenticated encryption)
# Don't just encrypt without authentication
4. Random IVs
# Generate new IV for each encryption
iv = os.urandom(16) # Different each time
Key Exchange
Diffie-Hellman
Agree on shared secret over insecure channel:
Alice: chooses a, sends: g^a mod p
Bob: chooses b, sends: g^b mod p
Shared secret:
Alice: (g^b)^a mod p = g^ab mod p
Bob: (g^a)^b mod p = g^ab mod p
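The exchange above can be sketched with toy numbers (a real deployment uses 2048-bit+ primes or elliptic curves):

```python
# Toy parameters - insecure, for illustration only
p = 23   # public prime modulus
g = 5    # public generator

a = 6    # Alice's secret exponent
b = 15   # Bob's secret exponent

A = pow(g, a, p)  # Alice sends g^a mod p over the insecure channel
B = pow(g, b, p)  # Bob sends g^b mod p

alice_shared = pow(B, a, p)  # (g^b)^a mod p
bob_shared = pow(A, b, p)    # (g^a)^b mod p
print(alice_shared == bob_shared)  # True - both hold g^(ab) mod p
```

An eavesdropper sees p, g, A, and B but would have to solve the discrete logarithm problem to recover the shared secret.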
TLS Handshake
1. Client hello
2. Server hello + certificate (contains public key)
3. Client generates pre-master secret, encrypts with public key
4. Both derive session key from pre-master secret
5. Encrypted communication begins
ELI10
Think of encryption as a locked box:
Symmetric:
- Same key locks and unlocks
- Fast but you need to share the key somehow
- Like: "Secret code 42" for both locking and unlocking
Asymmetric:
- Public key locks, private key unlocks
- Like: Anyone can put a message in a mailbox (public), but only owner has the key (private)
Why both?:
- Asymmetric slower but solves key-sharing problem
- Use asymmetric to exchange a symmetric key
- Then use fast symmetric key for actual data
Further Resources
- Cryptography.io Python
- OWASP Encryption Cheatsheet
- 3Blue1Brown Public Key Cryptography
- AES Explained
Digital Signatures
Overview
A digital signature is a cryptographic mechanism that provides:
- Authentication: Proves who created the signature
- Integrity: Detects any changes to the signed data
- Non-repudiation: Signer cannot deny signing (unlike HMAC)
Digital signatures use asymmetric cryptography (public/private key pairs).
Digital Signature vs HMAC
| Feature | Digital Signature | HMAC |
|---|---|---|
| Keys | Public/Private key pair | Shared secret key |
| Verification | Anyone with public key | Only parties with secret |
| Non-repudiation | Yes | No |
| Performance | Slower | Faster |
| Key distribution | Public key can be shared | Secret must be protected |
| Use case | Documents, software, certificates | API auth, sessions |
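For contrast, an HMAC tag needs only the standard library and a shared secret, which is exactly why it cannot provide non-repudiation (the key and message below are placeholders):

```python
import hashlib
import hmac

secret = b"shared-secret"
message = b"important message"
tag = hmac.new(secret, message, hashlib.sha256).hexdigest()

# Anyone holding `secret` can both create and verify the tag,
# so a verifier cannot prove to a third party who produced it
ok = hmac.compare_digest(tag, hmac.new(secret, message, hashlib.sha256).hexdigest())
print(ok)  # True
```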
How Digital Signatures Work
Signing Process
1. Hash the message
Message → Hash Function → Digest
2. Encrypt digest with private key
Digest → Private Key → Signature
3. Attach signature to message
Message + Signature → Signed Document
Verification Process
1. Hash the received message
Message → Hash Function → Digest₁
2. Decrypt signature with public key
Signature → Public Key → Digest₂
3. Compare digests
If Digest₁ == Digest₂ → Valid Signature
Visual Representation
SIGNING:
Message
|
Hash (SHA-256)
|
Digest
|
Encrypt with Private Key
|
Signature
|
Message + Signature
VERIFICATION:
Message + Signature
    |           |
Hash        Decrypt with
(SHA-256)   Public Key
    |           |
 Digest₁     Digest₂
    |           |
    +-----+-----+
          |
       Compare
          |
   Valid / Invalid
RSA Signatures
RSA Algorithm Overview
RSA uses modular arithmetic with large prime numbers:
- Key Generation: Create public (e, n) and private (d, n) keys
- Signing: signature = (hash)^d mod n
- Verification: hash = (signature)^e mod n
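The arithmetic above can be illustrated with classic textbook primes (insecure toy values; real RSA uses 2048-bit moduli plus padding such as PSS):

```python
# Textbook RSA parameters
p, q = 61, 53
n = p * q                 # modulus: 3233
phi = (p - 1) * (q - 1)   # 3120
e = 17                    # public exponent (coprime with phi)
d = pow(e, -1, phi)       # private exponent: 2753 (Python 3.8+ modular inverse)

h = 1234                  # pretend this is the message digest (must be < n)
signature = pow(h, d, n)          # sign: h^d mod n
recovered = pow(signature, e, n)  # verify: signature^e mod n
print(recovered == h)  # True
```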
Generating RSA Keys
Python (cryptography library)
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization
# Generate private key
private_key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048, # 2048 or 4096 bits
)
# Generate public key
public_key = private_key.public_key()
# Save private key
pem_private = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption()
)
with open('private_key.pem', 'wb') as f:
f.write(pem_private)
# Save public key
pem_public = public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
)
with open('public_key.pem', 'wb') as f:
f.write(pem_public)
OpenSSL (Bash)
# Generate private key (2048-bit RSA)
openssl genrsa -out private_key.pem 2048
# Generate private key with password protection
openssl genrsa -aes256 -out private_key.pem 2048
# Extract public key from private key
openssl rsa -in private_key.pem -pubout -out public_key.pem
# Generate 4096-bit key (more secure)
openssl genrsa -out private_key.pem 4096
# View key details
openssl rsa -in private_key.pem -text -noout
Signing with RSA
Python Example
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives import serialization
# Load private key
with open('private_key.pem', 'rb') as f:
private_key = serialization.load_pem_private_key(
f.read(),
password=None
)
# Message to sign
message = b"This is an important document"
# Sign message (RSA-PSS with SHA-256)
signature = private_key.sign(
message,
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH
),
hashes.SHA256()
)
print(f"Signature: {signature.hex()}")
print(f"Signature length: {len(signature)} bytes")
# Save signature
with open('signature.bin', 'wb') as f:
f.write(signature)
OpenSSL Example
# Sign a file
openssl dgst -sha256 -sign private_key.pem -out signature.bin document.txt
# Sign with different hash algorithms
openssl dgst -sha512 -sign private_key.pem -out signature.bin document.txt
# Create detached signature
openssl dgst -sha256 -sign private_key.pem -out document.sig document.pdf
Verifying RSA Signatures
Python Example
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives import serialization
from cryptography.exceptions import InvalidSignature
# Load public key
with open('public_key.pem', 'rb') as f:
public_key = serialization.load_pem_public_key(f.read())
# Load signature
with open('signature.bin', 'rb') as f:
signature = f.read()
# Message to verify
message = b"This is an important document"
# Verify signature
try:
public_key.verify(
signature,
message,
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH
),
hashes.SHA256()
)
print("✓ Signature is valid!")
except InvalidSignature:
print("✗ Invalid signature!")
# Complete example
def verify_document(public_key_path, document, signature):
with open(public_key_path, 'rb') as f:
public_key = serialization.load_pem_public_key(f.read())
try:
public_key.verify(
signature,
document,
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH
),
hashes.SHA256()
)
return True
except InvalidSignature:
return False
OpenSSL Example
# Verify signature
openssl dgst -sha256 -verify public_key.pem -signature signature.bin document.txt
# Output:
# Verified OK (if valid)
# Verification Failure (if invalid)
# Verify detached signature
openssl dgst -sha256 -verify public_key.pem -signature document.sig document.pdf
RSA Padding Schemes
PKCS#1 v1.5 (Legacy)
from cryptography.hazmat.primitives.asymmetric import padding
# Sign with PKCS#1 v1.5 (not recommended)
signature = private_key.sign(
message,
padding.PKCS1v15(),
hashes.SHA256()
)
PSS (Recommended)
# Sign with PSS (Probabilistic Signature Scheme)
signature = private_key.sign(
message,
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH
),
hashes.SHA256()
)
ECDSA (Elliptic Curve Digital Signature Algorithm)
Overview
ECDSA provides equivalent security to RSA with much smaller keys:
- RSA 2048-bit ≈ ECDSA 224-bit
- RSA 3072-bit ≈ ECDSA 256-bit
- RSA 15360-bit ≈ ECDSA 512-bit
Benefits:
- Smaller keys
- Faster signing
- Less bandwidth
- Less storage
Common Curves
| Curve | Bits | Security | Use Case |
|---|---|---|---|
| P-256 (secp256r1) | 256 | ~128-bit | General purpose, TLS |
| P-384 (secp384r1) | 384 | ~192-bit | High security |
| P-521 (secp521r1) | 521 | ~256-bit | Maximum security |
| secp256k1 | 256 | ~128-bit | Bitcoin, cryptocurrencies |
Generating ECDSA Keys
Python Example
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import serialization
# Generate private key (P-256 curve)
private_key = ec.generate_private_key(ec.SECP256R1())
# Extract public key
public_key = private_key.public_key()
# Save private key
pem_private = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption()
)
with open('ec_private_key.pem', 'wb') as f:
f.write(pem_private)
# Save public key
pem_public = public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
)
with open('ec_public_key.pem', 'wb') as f:
f.write(pem_public)
# Different curves
# P-256 (most common)
key_p256 = ec.generate_private_key(ec.SECP256R1())
# P-384 (higher security)
key_p384 = ec.generate_private_key(ec.SECP384R1())
# P-521 (maximum security)
key_p521 = ec.generate_private_key(ec.SECP521R1())
# secp256k1 (Bitcoin)
key_secp256k1 = ec.generate_private_key(ec.SECP256K1())
OpenSSL Example
# Generate EC private key (P-256)
openssl ecparam -name prime256v1 -genkey -noout -out ec_private_key.pem
# Generate with P-384
openssl ecparam -name secp384r1 -genkey -noout -out ec_private_key.pem
# Generate with P-521
openssl ecparam -name secp521r1 -genkey -noout -out ec_private_key.pem
# Extract public key
openssl ec -in ec_private_key.pem -pubout -out ec_public_key.pem
# View key details
openssl ec -in ec_private_key.pem -text -noout
# List available curves
openssl ecparam -list_curves
Signing with ECDSA
Python Example
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import serialization
# Load private key
with open('ec_private_key.pem', 'rb') as f:
private_key = serialization.load_pem_private_key(
f.read(),
password=None
)
# Message to sign
message = b"ECDSA signature example"
# Sign message
signature = private_key.sign(
message,
ec.ECDSA(hashes.SHA256())
)
print(f"ECDSA Signature: {signature.hex()}")
print(f"Signature length: {len(signature)} bytes")
# For P-256, signature is ~64 bytes (vs ~256 bytes for RSA-2048!)
OpenSSL Example
# Sign with ECDSA
openssl dgst -sha256 -sign ec_private_key.pem -out ecdsa_signature.bin document.txt
# Verify ECDSA signature
openssl dgst -sha256 -verify ec_public_key.pem -signature ecdsa_signature.bin document.txt
Verifying ECDSA Signatures
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import serialization
from cryptography.exceptions import InvalidSignature
# Load public key
with open('ec_public_key.pem', 'rb') as f:
public_key = serialization.load_pem_public_key(f.read())
# Message and signature
message = b"ECDSA signature example"
with open('ecdsa_signature.bin', 'rb') as f:
signature = f.read()
# Verify signature
try:
public_key.verify(
signature,
message,
ec.ECDSA(hashes.SHA256())
)
print("✓ ECDSA signature is valid!")
except InvalidSignature:
print("✗ Invalid ECDSA signature!")
EdDSA (Edwards-curve Digital Signature Algorithm)
Overview
EdDSA is a modern signature scheme designed for high performance and security.
Ed25519 (most common):
- 256-bit keys
- Fast signing and verification
- Deterministic (no random number needed)
- Resistant to side-channel attacks
Generating Ed25519 Keys
from cryptography.hazmat.primitives.asymmetric import ed25519
from cryptography.hazmat.primitives import serialization
# Generate private key
private_key = ed25519.Ed25519PrivateKey.generate()
# Extract public key
public_key = private_key.public_key()
# Save private key
pem_private = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption()
)
with open('ed25519_private_key.pem', 'wb') as f:
f.write(pem_private)
# Save public key
pem_public = public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
)
with open('ed25519_public_key.pem', 'wb') as f:
f.write(pem_public)
# Raw bytes format (32 bytes each)
private_bytes = private_key.private_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PrivateFormat.Raw,
encryption_algorithm=serialization.NoEncryption()
)
public_bytes = public_key.public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw
)
print(f"Private key: {private_bytes.hex()} ({len(private_bytes)} bytes)")
print(f"Public key: {public_bytes.hex()} ({len(public_bytes)} bytes)")
Signing with Ed25519
from cryptography.hazmat.primitives.asymmetric import ed25519
from cryptography.exceptions import InvalidSignature
# Generate key
private_key = ed25519.Ed25519PrivateKey.generate()
# Message to sign
message = b"Ed25519 is fast and secure!"
# Sign (deterministic, no hash parameter needed - Ed25519 hashes internally)
signature = private_key.sign(message)
print(f"Ed25519 Signature: {signature.hex()}")
print(f"Signature length: {len(signature)} bytes") # Always 64 bytes
# Verify
public_key = private_key.public_key()
try:
    public_key.verify(signature, message)
    print("✓ Signature valid!")
except InvalidSignature:
    print("✗ Invalid signature!")
Performance Comparison
import time
from cryptography.hazmat.primitives.asymmetric import rsa, ec, ed25519
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
message = b"Performance test message"
# RSA
rsa_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
start = time.time()
for _ in range(1000):
sig = rsa_key.sign(message, padding.PSS(mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH), hashes.SHA256())
rsa_time = time.time() - start
# ECDSA
ec_key = ec.generate_private_key(ec.SECP256R1())
start = time.time()
for _ in range(1000):
sig = ec_key.sign(message, ec.ECDSA(hashes.SHA256()))
ecdsa_time = time.time() - start
# Ed25519
ed_key = ed25519.Ed25519PrivateKey.generate()
start = time.time()
for _ in range(1000):
sig = ed_key.sign(message)
ed25519_time = time.time() - start
print(f"RSA-2048: {rsa_time:.3f}s")
print(f"ECDSA-256: {ecdsa_time:.3f}s")
print(f"Ed25519: {ed25519_time:.3f}s")
# Typical results:
# RSA-2048: 5.234s (slowest)
# ECDSA-256: 1.876s (fast)
# Ed25519: 0.156s (fastest!)
Signature Verification
Complete Verification Example
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa, padding
from cryptography.hazmat.primitives import hashes
from cryptography.exceptions import InvalidSignature
class DocumentSigner:
def __init__(self, private_key_path=None, public_key_path=None):
if private_key_path:
with open(private_key_path, 'rb') as f:
self.private_key = serialization.load_pem_private_key(
f.read(),
password=None
)
if public_key_path:
with open(public_key_path, 'rb') as f:
self.public_key = serialization.load_pem_public_key(f.read())
def sign_document(self, document):
signature = self.private_key.sign(
document,
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH
),
hashes.SHA256()
)
return signature
def verify_document(self, document, signature):
try:
self.public_key.verify(
signature,
document,
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH
),
hashes.SHA256()
)
return True, "Signature is valid"
except InvalidSignature:
return False, "Invalid signature - document may have been tampered with"
except Exception as e:
return False, f"Verification error: {str(e)}"
def sign_file(self, filepath, signature_path):
with open(filepath, 'rb') as f:
document = f.read()
signature = self.sign_document(document)
with open(signature_path, 'wb') as f:
f.write(signature)
return signature
def verify_file(self, filepath, signature_path):
with open(filepath, 'rb') as f:
document = f.read()
with open(signature_path, 'rb') as f:
signature = f.read()
return self.verify_document(document, signature)
# Usage
# Signing
signer = DocumentSigner(private_key_path='private_key.pem')
document = b"Important contract: Alice pays Bob $1000"
signature = signer.sign_document(document)
# Verification
verifier = DocumentSigner(public_key_path='public_key.pem')
is_valid, message = verifier.verify_document(document, signature)
print(f"{message}")
# File signing
signer.sign_file('contract.pdf', 'contract.pdf.sig')
is_valid, message = verifier.verify_file('contract.pdf', 'contract.pdf.sig')
print(f"Contract verification: {message}")
Code Signing
Signing Software/Scripts
# Sign a Python script
openssl dgst -sha256 -sign private_key.pem -out script.py.sig script.py
# Create a signed package
tar -czf package.tar.gz files/
openssl dgst -sha256 -sign private_key.pem -out package.tar.gz.sig package.tar.gz
# Verification script
#!/bin/bash
FILE="$1"
SIG="$2"
PUBKEY="$3"
openssl dgst -sha256 -verify "$PUBKEY" -signature "$SIG" "$FILE"
if [ $? -eq 0 ]; then
echo "✓ Signature verified - safe to run"
else
echo "✗ Invalid signature - DO NOT RUN"
exit 1
fi
Python Code Signing Example
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives import hashes
from cryptography.exceptions import InvalidSignature

class CodeSigner:
    def __init__(self, private_key_path):
        with open(private_key_path, 'rb') as f:
            self.private_key = serialization.load_pem_private_key(
                f.read(),
                password=None
            )

    def sign_file(self, filepath):
        # Read file
        with open(filepath, 'rb') as f:
            code = f.read()
        # Generate signature
        signature = self.private_key.sign(
            code,
            padding.PSS(
                mgf=padding.MGF1(hashes.SHA256()),
                salt_length=padding.PSS.MAX_LENGTH
            ),
            hashes.SHA256()
        )
        # Save signature
        sig_path = filepath + '.sig'
        with open(sig_path, 'wb') as f:
            f.write(signature)
        print(f"✓ Signed: {filepath}")
        print(f"✓ Signature: {sig_path}")
        return signature

class CodeVerifier:
    def __init__(self, public_key_path):
        with open(public_key_path, 'rb') as f:
            self.public_key = serialization.load_pem_public_key(f.read())

    def verify_file(self, filepath):
        sig_path = filepath + '.sig'
        # Read file and signature
        with open(filepath, 'rb') as f:
            code = f.read()
        with open(sig_path, 'rb') as f:
            signature = f.read()
        # Verify
        try:
            self.public_key.verify(
                signature,
                code,
                padding.PSS(
                    mgf=padding.MGF1(hashes.SHA256()),
                    salt_length=padding.PSS.MAX_LENGTH
                ),
                hashes.SHA256()
            )
            print(f"✓ Signature valid for {filepath}")
            return True
        except InvalidSignature:
            print(f"✗ Invalid signature for {filepath}")
            return False

# Usage
signer = CodeSigner('private_key.pem')
signer.sign_file('important_script.py')
verifier = CodeVerifier('public_key.pem')
if verifier.verify_file('important_script.py'):
    # Safe to execute
    exec(open('important_script.py').read())
macOS Code Signing
# Sign application
codesign -s "Developer ID" MyApp.app
# Verify signature
codesign -v MyApp.app
# Deep verification
codesign -v --deep MyApp.app
# Display signature info
codesign -d -vv MyApp.app
Windows Code Signing
# Sign executable with certificate
signtool sign /f certificate.pfx /p password /t http://timestamp.server.com app.exe
# Verify signature
signtool verify /pa app.exe
Security Considerations
1. Key Size
RSA:
- Minimum: 2048 bits
- Recommended: 3072 bits
- High security: 4096 bits
ECDSA:
- Minimum: 256 bits (P-256)
- Recommended: 384 bits (P-384)
- High security: 521 bits (P-521)
Ed25519:
- Fixed: 256 bits (equivalent to ~128-bit security)
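To make these sizes tangible, here's a small sketch (using the same `cryptography` library as the other examples in these notes) that generates one key of each type and prints the resulting signature lengths:

```python
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import rsa, ec, ed25519, padding

msg = b"size comparison"

# RSA-2048: signature is as long as the modulus (256 bytes)
rsa_sig = rsa.generate_private_key(public_exponent=65537, key_size=2048).sign(
    msg,
    padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH),
    hashes.SHA256(),
)

# ECDSA P-256: DER-encoded (r, s) pair, typically 70-72 bytes
ec_sig = ec.generate_private_key(ec.SECP256R1()).sign(msg, ec.ECDSA(hashes.SHA256()))

# Ed25519: always exactly 64 bytes
ed_sig = ed25519.Ed25519PrivateKey.generate().sign(msg)

print(f"RSA-2048 signature:   {len(rsa_sig)} bytes")
print(f"ECDSA P-256 signature: {len(ec_sig)} bytes")
print(f"Ed25519 signature:     {len(ed_sig)} bytes")
```

Larger keys buy more security margin but also larger, slower signatures; Ed25519 stays small and fixed-size.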
2. Hash Function
# GOOD - SHA-256 or better
signature = private_key.sign(message, padding.PSS(...), hashes.SHA256())
# BETTER - SHA-512
signature = private_key.sign(message, padding.PSS(...), hashes.SHA512())
# BAD - SHA-1 (broken!)
signature = private_key.sign(message, padding.PSS(...), hashes.SHA1())
3. Random Number Generation
# ECDSA requires good randomness
# Python's cryptography library handles this automatically
# NEVER implement your own random number generator!
# Use os.urandom() or secrets module for any manual crypto
import secrets
random_bytes = secrets.token_bytes(32)
4. Private Key Protection
# Encrypt private key with password
from cryptography.hazmat.primitives import serialization
encrypted_pem = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.BestAvailableEncryption(b'strong-password')
)
# Load encrypted key
with open('encrypted_key.pem', 'rb') as f:
private_key = serialization.load_pem_private_key(
f.read(),
password=b'strong-password'
)
5. Signature Malleability
Some signature schemes allow multiple valid signatures for the same message.
Ed25519: NOT malleable (good!)
ECDSA: Malleable - for a valid (r, s), the pair (r, n - s) also verifies; use the canonical low-s form
RSA-PSS: Probabilistic rather than malleable - the signer emits a different signature each time, but third parties cannot derive new valid ones
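The last point can be checked directly. A short sketch showing that RSA-PSS produces a different signature on every call (all of which verify), while Ed25519 is fully deterministic:

```python
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import rsa, ed25519, padding

message = b"determinism demo"

# RSA-PSS uses a random salt: two signatures over the same message differ...
rsa_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
pss = padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH)
sig1 = rsa_key.sign(message, pss, hashes.SHA256())
sig2 = rsa_key.sign(message, pss, hashes.SHA256())
print("RSA-PSS signatures differ:", sig1 != sig2)

# ...but both verify (verify() raises InvalidSignature on failure)
rsa_key.public_key().verify(sig1, message, pss, hashes.SHA256())
rsa_key.public_key().verify(sig2, message, pss, hashes.SHA256())

# Ed25519 is deterministic: same key + same message -> same signature
ed_key = ed25519.Ed25519PrivateKey.generate()
print("Ed25519 deterministic:", ed_key.sign(message) == ed_key.sign(message))
```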
Best Practices
1. Use Modern Algorithms
✓ RSA-PSS (not PKCS#1 v1.5)
✓ ECDSA with P-256 or better
✓ Ed25519 (best choice for new systems)
✗ DSA (obsolete)
✗ RSA with PKCS#1 v1.5 (vulnerable)
2. Protect Private Keys
- Never commit to version control
- Use hardware security modules (HSM) for critical keys
- Use key management services (AWS KMS, Azure Key Vault)
- Encrypt keys at rest
- Limit access with proper permissions
3. Include Metadata
import json
import time
def create_signed_document(content, private_key):
metadata = {
'content': content,
'timestamp': int(time.time()),
'signer': 'John Doe',
'version': '1.0'
}
message = json.dumps(metadata, sort_keys=True).encode()
signature = private_key.sign(message, ...)
return {
'metadata': metadata,
'signature': signature.hex()
}
4. Timestamp Signatures
# Include timestamp to prevent replay attacks
import time
def sign_with_timestamp(message, private_key):
timestamp = str(int(time.time()))
data = f"{timestamp}:{message}".encode()
signature = private_key.sign(data, ...)
return {
'message': message,
'timestamp': timestamp,
'signature': signature.hex()
}
def verify_with_timestamp(signed_data, public_key, max_age=3600):
timestamp = int(signed_data['timestamp'])
current = int(time.time())
# Check if too old
if current - timestamp > max_age:
return False, "Signature expired"
# Verify signature
data = f"{signed_data['timestamp']}:{signed_data['message']}".encode()
# ... verify logic
Common Mistakes
1. Signing Hash vs Message
# WRONG - passing a pre-computed hash as the message hashes it again
hash_digest = hashlib.sha256(message).digest()
signature = private_key.sign(hash_digest, ...)  # Signs SHA-256(SHA-256(message))!
# RIGHT - pass the full message and let the library hash it
signature = private_key.sign(message, ..., hashes.SHA256())
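If you genuinely need to sign a digest computed elsewhere (for example, hashed while streaming a large file), the `cryptography` library supports this explicitly via `Prehashed` instead of passing the raw digest as a message. A minimal sketch:

```python
import hashlib

from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding, rsa, utils

private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
message = b"large file contents"

# Digest computed separately (e.g. incrementally over a big file)
digest = hashlib.sha256(message).digest()

# Prehashed tells the library the input is already a SHA-256 digest
pss = padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH)
signature = private_key.sign(digest, pss, utils.Prehashed(hashes.SHA256()))

# Verification works against the original message as usual
private_key.public_key().verify(signature, message, pss, hashes.SHA256())
print("Prehashed signature verified")
```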
2. Not Validating Signatures
# WRONG - trusting unsigned data
data = receive_data()
process(data) # Danger!
# RIGHT - verify signature first
data, signature = receive_data_and_signature()
if verify_signature(data, signature):
process(data)
else:
reject()
3. Exposing Private Keys
# WRONG
private_key = "-----BEGIN PRIVATE KEY-----\n..." # Hardcoded!
# RIGHT
import os
key_path = os.environ.get('PRIVATE_KEY_PATH')
with open(key_path, 'rb') as f:
private_key = load_key(f.read())
ELI10
Digital signatures are like a special seal that only you can make:
Regular signature (on paper):
- Anyone can try to copy your signature
- Hard to prove it's really yours
Digital signature:
- You have a special "stamp" that only you own (private key)
- Anyone can see your "stamp pattern" (public key)
- When you sign a document:
- You use your secret stamp to make a unique mark
- This mark is different for every document
- Others can verify:
- They use your public stamp pattern
- If it matches, they know YOU signed it
- Nobody else could have made that exact mark!
Why it's secure:
- Your secret stamp is like a lock that only you can use
- The public pattern lets others check your work
- Even if someone copies the signed document, they can't change it without your secret stamp!
Real-world example: When you download software, the developer signs it:
- ✓ You can verify it's really from them
- ✓ Nobody tampered with the software
- ✓ The developer can't deny they released it
Different from HMAC:
- HMAC: Shared secret (like both having the same password)
- Digital Signature: Private/public keys (like a lock and key everyone can see fits)
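That contrast fits in a few lines: HMAC needs the same secret on both sides (so any verifier could also forge), while a signature verifies with the public key alone:

```python
import hashlib
import hmac

from cryptography.hazmat.primitives.asymmetric import ed25519

message = b"same message, two integrity mechanisms"

# HMAC: one shared secret; whoever can verify can also create tags
shared_secret = b"both sides know this"
tag = hmac.new(shared_secret, message, hashlib.sha256).digest()
recomputed = hmac.new(shared_secret, message, hashlib.sha256).digest()
print("HMAC verifies:", hmac.compare_digest(tag, recomputed))

# Signature: only the private key can sign; the public key can only verify
private_key = ed25519.Ed25519PrivateKey.generate()
signature = private_key.sign(message)
private_key.public_key().verify(signature, message)  # raises InvalidSignature on failure
print("Ed25519 signature verified with the public key alone")
```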
Further Resources
- RSA Cryptography Explained
- ECDSA Deep Dive
- Ed25519 Specification
- Digital Signatures Standard (DSS)
- Cryptography Engineering (Book)
- Python Cryptography Library
- OpenSSL Command Reference
X.509 Certificates and PKI
Overview
X.509 certificates are digital documents that bind public keys to identities. They enable:
- Authentication: Verify identity of servers/users
- Encryption: Establish secure connections
- Trust: Chain of trust through Certificate Authorities
X.509 Certificate Structure
Basic Components
Certificate:
├── Version (v3)
├── Serial Number (unique identifier)
├── Signature Algorithm (SHA-256 with RSA)
├── Issuer (who issued the certificate)
├── Validity Period
│ ├── Not Before (start date)
│ └── Not After (expiration date)
├── Subject (who the certificate is for)
├── Subject Public Key Info
│ ├── Algorithm (RSA, ECDSA, etc.)
│ └── Public Key (actual key data)
├── Extensions (v3)
│ ├── Key Usage
│ ├── Subject Alternative Names (SANs)
│ ├── Basic Constraints
│ └── Authority Key Identifier
└── Signature (CA's signature)
Certificate Fields
Subject: CN=example.com, O=Example Inc, C=US
CN = Common Name (domain or person name)
O = Organization
OU = Organizational Unit
C = Country
ST = State/Province
L = Locality/City
Issuer: CN=Let's Encrypt Authority, O=Let's Encrypt, C=US
(Who signed this certificate)
Validity:
Not Before: Jan 1 00:00:00 2024 GMT
Not After: Apr 1 23:59:59 2024 GMT
(Certificate valid period)
Public Key Algorithm: RSA 2048-bit
(Type and size of public key)
Signature Algorithm: SHA-256 with RSA
(How CA signed the certificate)
Visual Representation
┌─────────────────────────────────────┐
│ X.509 Certificate │
├─────────────────────────────────────┤
│ Version: 3 │
│ Serial: 04:92:7f:63:ab:02:1e... │
│ │
│ Issuer: CN=Let's Encrypt │
│ Subject: CN=example.com │
│ │
│ Valid: 2024-01-01 to 2024-04-01 │
│ │
│ Public Key: [RSA 2048-bit] │
│ 65537 │
│ 00:b8:7f:4e:91... │
│ │
│ Extensions: │
│ - Key Usage: Digital Signature │
│ - SANs: example.com, *.example.com│
│ - Basic Constraints: CA:FALSE │
│ │
│ Signature Algorithm: sha256RSA │
│ Signature: [CA's signature] │
│ 3a:7b:8c:9d... │
└─────────────────────────────────────┘
Certificate Creation
Creating a Self-Signed Certificate
OpenSSL (Bash)
# Generate private key and self-signed certificate in one command
openssl req -x509 -newkey rsa:2048 -keyout key.pem -out cert.pem -days 365 -nodes \
-subj "/C=US/ST=California/L=San Francisco/O=Example Inc/CN=example.com"
# Breakdown:
# -x509: Create self-signed certificate
# -newkey rsa:2048: Generate new 2048-bit RSA key
# -keyout: Output private key file
# -out: Output certificate file
# -days: Certificate validity period
# -nodes: Don't encrypt private key
# -subj: Certificate subject information
# View certificate details
openssl x509 -in cert.pem -text -noout
# Generate key and certificate separately
openssl genrsa -out key.pem 2048
openssl req -new -x509 -key key.pem -out cert.pem -days 365 \
-subj "/CN=example.com"
Python
from cryptography import x509
from cryptography.x509.oid import NameOID, ExtensionOID
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization
import datetime
# Generate private key
private_key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048,
)
# Create subject and issuer (same for self-signed)
subject = issuer = x509.Name([
x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "California"),
x509.NameAttribute(NameOID.LOCALITY_NAME, "San Francisco"),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Example Inc"),
x509.NameAttribute(NameOID.COMMON_NAME, "example.com"),
])
# Build certificate
cert = x509.CertificateBuilder().subject_name(
subject
).issuer_name(
issuer
).public_key(
private_key.public_key()
).serial_number(
x509.random_serial_number()
).not_valid_before(
datetime.datetime.utcnow()
).not_valid_after(
datetime.datetime.utcnow() + datetime.timedelta(days=365)
).add_extension(
x509.SubjectAlternativeName([
x509.DNSName("example.com"),
x509.DNSName("www.example.com"),
]),
critical=False,
).sign(private_key, hashes.SHA256())
# Save certificate
with open("cert.pem", "wb") as f:
f.write(cert.public_bytes(serialization.Encoding.PEM))
# Save private key
with open("key.pem", "wb") as f:
f.write(private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption()
))
print("Certificate created successfully!")
Creating a Certificate Signing Request (CSR)
OpenSSL
# Generate private key
openssl genrsa -out server.key 2048
# Create CSR
openssl req -new -key server.key -out server.csr \
-subj "/C=US/ST=CA/L=San Francisco/O=Example Inc/CN=example.com"
# View CSR
openssl req -in server.csr -text -noout
# Create CSR with Subject Alternative Names (using config file)
cat > san.cnf <<-END
[req]
default_bits = 2048
prompt = no
default_md = sha256
distinguished_name = dn
req_extensions = v3_req
[dn]
C=US
ST=CA
L=San Francisco
O=Example Inc
CN=example.com
[v3_req]
subjectAltName = @alt_names
[alt_names]
DNS.1 = example.com
DNS.2 = www.example.com
DNS.3 = *.example.com
END
openssl req -new -key server.key -out server.csr -config san.cnf
# Verify CSR
openssl req -in server.csr -noout -verify
Python
from cryptography import x509
from cryptography.x509.oid import NameOID
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization
# Generate private key
private_key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048,
)
# Build CSR
csr = x509.CertificateSigningRequestBuilder().subject_name(x509.Name([
x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "California"),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Example Inc"),
x509.NameAttribute(NameOID.COMMON_NAME, "example.com"),
])).add_extension(
x509.SubjectAlternativeName([
x509.DNSName("example.com"),
x509.DNSName("www.example.com"),
x509.DNSName("*.example.com"),
]),
critical=False,
).sign(private_key, hashes.SHA256())
# Save CSR
with open("server.csr", "wb") as f:
f.write(csr.public_bytes(serialization.Encoding.PEM))
print("CSR created successfully!")
Certificate Authorities (CAs)
CA Hierarchy
┌────────────────────────────┐
│ Root CA │
│ (Self-signed) │
│ Trust Anchor │
└─────────────┬──────────────┘
│
┌─────────┴─────────┐
│ │
┌───▼──────────┐ ┌────▼───────────┐
│ Intermediate │ │ Intermediate │
│ CA #1 │ │ CA #2 │
└───┬──────────┘ └────┬───────────┘
│ │
┌───▼──────┐ ┌────▼──────┐
│ End-User │ │ End-User │
│ Cert #1 │ │ Cert #2 │
└──────────┘ └───────────┘
Trust Chain
End-user certificate (example.com)
↓ Issued by
Intermediate CA certificate
↓ Issued by
Root CA certificate (in browser trust store)
✓ Trusted
Setting Up a CA
Create Root CA
# Generate Root CA private key
openssl genrsa -aes256 -out rootCA.key 4096
# Create Root CA certificate
openssl req -x509 -new -nodes -key rootCA.key -sha256 -days 3650 \
-out rootCA.crt \
-subj "/C=US/ST=CA/O=Example Inc/CN=Example Root CA"
# View Root CA certificate
openssl x509 -in rootCA.crt -text -noout
Sign Certificate with CA
# You have: server.csr (from earlier)
# You have: rootCA.key and rootCA.crt
# Create extensions configuration
echo "
[ v3_req ]
basicConstraints = CA:FALSE
keyUsage = nonRepudiation, digitalSignature, keyEncipherment
subjectAltName = @alt_names
[ alt_names ]
DNS.1 = example.com
DNS.2 = www.example.com
DNS.3 = *.example.com
" > server_ext.cnf
# Sign CSR with CA
openssl x509 -req -in server.csr \
-CA rootCA.crt -CAkey rootCA.key -CAcreateserial \
-out server.crt -days 365 -sha256 \
-extfile server_ext.cnf -extensions v3_req
# View signed certificate
openssl x509 -in server.crt -text -noout
# Verify certificate against CA
openssl verify -CAfile rootCA.crt server.crt
Python CA Implementation
from cryptography import x509
from cryptography.x509.oid import NameOID, ExtensionOID
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization
import datetime
class CertificateAuthority:
def __init__(self):
# Generate CA private key
self.ca_key = rsa.generate_private_key(
public_exponent=65537,
key_size=4096,
)
# Create CA certificate
subject = issuer = x509.Name([
x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Example Inc"),
x509.NameAttribute(NameOID.COMMON_NAME, "Example Root CA"),
])
self.ca_cert = x509.CertificateBuilder().subject_name(
subject
).issuer_name(
issuer
).public_key(
self.ca_key.public_key()
).serial_number(
x509.random_serial_number()
).not_valid_before(
datetime.datetime.utcnow()
).not_valid_after(
datetime.datetime.utcnow() + datetime.timedelta(days=3650)
).add_extension(
x509.BasicConstraints(ca=True, path_length=None),
critical=True,
).add_extension(
x509.KeyUsage(
digital_signature=True,
key_cert_sign=True,
crl_sign=True,
key_encipherment=False,
content_commitment=False,
data_encipherment=False,
key_agreement=False,
encipher_only=False,
decipher_only=False,
),
critical=True,
).sign(self.ca_key, hashes.SHA256())
def issue_certificate(self, csr, validity_days=365):
"""Issue a certificate from a CSR"""
cert = x509.CertificateBuilder().subject_name(
csr.subject
).issuer_name(
self.ca_cert.subject
).public_key(
csr.public_key()
).serial_number(
x509.random_serial_number()
).not_valid_before(
datetime.datetime.utcnow()
).not_valid_after(
datetime.datetime.utcnow() + datetime.timedelta(days=validity_days)
).add_extension(
x509.BasicConstraints(ca=False, path_length=None),
critical=True,
).add_extension(
x509.KeyUsage(
digital_signature=True,
key_encipherment=True,
key_cert_sign=False,
crl_sign=False,
content_commitment=False,
data_encipherment=False,
key_agreement=False,
encipher_only=False,
decipher_only=False,
),
critical=True,
)
# Copy requested extensions (e.g. SANs) from the CSR
# (assumes the CSR does not repeat extensions already added above,
#  which would raise a duplicate-extension error)
for extension in csr.extensions:
    cert = cert.add_extension(extension.value, extension.critical)
# Sign with CA key
return cert.sign(self.ca_key, hashes.SHA256())
def save_ca_cert(self, filename):
with open(filename, "wb") as f:
f.write(self.ca_cert.public_bytes(serialization.Encoding.PEM))
def save_ca_key(self, filename, password=None):
encryption = serialization.NoEncryption()
if password:
encryption = serialization.BestAvailableEncryption(password)
with open(filename, "wb") as f:
f.write(self.ca_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=encryption
))
# Usage
ca = CertificateAuthority()
ca.save_ca_cert("ca.crt")
ca.save_ca_key("ca.key", password=b"secure-password")
# Load and sign a CSR
with open("server.csr", "rb") as f:
csr = x509.load_pem_x509_csr(f.read())
cert = ca.issue_certificate(csr, validity_days=365)
with open("server.crt", "wb") as f:
f.write(cert.public_bytes(serialization.Encoding.PEM))
print("Certificate issued successfully!")
Certificate Chains
Understanding Certificate Chains
┌─────────────────────────────────┐
│ Server Certificate │
│ Subject: CN=example.com │
│ Issuer: CN=Intermediate CA │
│ [Public Key] │
│ [Signature by Intermediate] │
└────────────┬────────────────────┘
│ Verified by
┌────────────▼────────────────────┐
│ Intermediate Certificate │
│ Subject: CN=Intermediate CA │
│ Issuer: CN=Root CA │
│ [Public Key] │
│ [Signature by Root] │
└────────────┬────────────────────┘
│ Verified by
┌────────────▼────────────────────┐
│ Root Certificate │
│ Subject: CN=Root CA │
│ Issuer: CN=Root CA (self) │
│ [Public Key] │
│ [Self Signature] │
│ ✓ In Trust Store │
└─────────────────────────────────┘
Building Certificate Chain
# Create chain file (server cert + intermediate cert)
cat server.crt intermediate.crt > fullchain.pem
# Or with root CA (not usually needed)
cat server.crt intermediate.crt rootCA.crt > fullchain.pem
# Verify chain
openssl verify -CAfile rootCA.crt -untrusted intermediate.crt server.crt
# Display certificate chain
openssl s_client -connect example.com:443 -showcerts
Verifying Certificate Chain in Python
from cryptography import x509
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives import hashes
from cryptography.exceptions import InvalidSignature
def verify_certificate_chain(cert_chain):
"""
Verify a certificate chain
cert_chain: list of certificates [leaf, intermediate, ..., root]
"""
for i in range(len(cert_chain) - 1):
cert = cert_chain[i]
issuer_cert = cert_chain[i + 1]
# Verify issuer name matches
if cert.issuer != issuer_cert.subject:
return False, f"Issuer mismatch at level {i}"
# Verify signature (assumes an RSA issuer key using PKCS#1 v1.5,
# the most common scheme for CA-signed certificates)
try:
issuer_public_key = issuer_cert.public_key()
issuer_public_key.verify(
cert.signature,
cert.tbs_certificate_bytes,
padding.PKCS1v15(),
cert.signature_hash_algorithm,
)
except InvalidSignature:
return False, f"Invalid signature at level {i}"
# Verify validity period
import datetime
now = datetime.datetime.utcnow()
if now < cert.not_valid_before or now > cert.not_valid_after:
return False, f"Certificate expired or not yet valid at level {i}"
return True, "Chain verified successfully"
# Load certificates
certs = []
for cert_file in ['server.crt', 'intermediate.crt', 'root.crt']:
with open(cert_file, 'rb') as f:
cert = x509.load_pem_x509_certificate(f.read())
certs.append(cert)
# Verify chain
is_valid, message = verify_certificate_chain(certs)
print(message)
Let's Encrypt
Overview
Let's Encrypt is a free, automated Certificate Authority providing:
- Free SSL/TLS certificates
- 90-day validity (encourages automation)
- Domain Validation (DV) only
- Automated renewal
ACME Protocol
1. Client requests certificate for example.com
2. Let's Encrypt challenges ownership:
- HTTP-01: Place file at http://example.com/.well-known/acme-challenge/
- DNS-01: Add TXT record to _acme-challenge.example.com
- TLS-ALPN-01: Configure TLS server with special certificate
3. Let's Encrypt verifies challenge
4. If successful, issues certificate
5. Client installs certificate
6. Automated renewal before 90-day expiration
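To see what the HTTP-01 challenge amounts to on disk, here's a minimal sketch. The token and key-authorization strings are made-up placeholders (a real ACME client obtains them during the order), and `./webroot` stands in for your web server's document root:

```shell
# Placeholder values - a real ACME client generates these
TOKEN="example-token-LoqXcYV8q5ONbJQxbmR7SCTNo3tiAXDfowyjxAjEuX0"
KEYAUTH="$TOKEN.example-thumbprint-9jg46WB3rR_AHD-EBXdN7cBkH1WOu0tA3M9fm21mqTI"

# HTTP-01: the CA fetches this exact path over plain HTTP on port 80
WEBROOT=./webroot   # stand-in for e.g. /var/www/html
mkdir -p "$WEBROOT/.well-known/acme-challenge"
printf '%s' "$KEYAUTH" > "$WEBROOT/.well-known/acme-challenge/$TOKEN"

# The CA then requests:
#   http://example.com/.well-known/acme-challenge/$TOKEN
# and checks that the response body equals the key authorization
```

Certbot's `--webroot` mode (below) automates exactly this file placement and cleanup.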
Using Certbot
# Install certbot
sudo apt-get install certbot
# Obtain certificate (standalone)
sudo certbot certonly --standalone -d example.com -d www.example.com
# Obtain certificate (webroot - site already running)
sudo certbot certonly --webroot -w /var/www/html -d example.com
# Obtain certificate (DNS challenge)
sudo certbot certonly --manual --preferred-challenges dns -d example.com
# Obtain certificate (with automatic nginx configuration)
sudo certbot --nginx -d example.com -d www.example.com
# Obtain certificate (with automatic apache configuration)
sudo certbot --apache -d example.com
# List certificates
sudo certbot certificates
# Renew certificates (dry run)
sudo certbot renew --dry-run
# Renew certificates
sudo certbot renew
# Revoke certificate
sudo certbot revoke --cert-path /etc/letsencrypt/live/example.com/cert.pem
# Delete certificate
sudo certbot delete --cert-name example.com
Automated Renewal
# Add to crontab (check renewal twice daily)
0 0,12 * * * certbot renew --quiet
# Systemd timer (if using systemd)
sudo systemctl enable certbot-renew.timer
sudo systemctl start certbot-renew.timer
# Test renewal
sudo certbot renew --dry-run
Using acme.sh (Alternative)
# Install acme.sh
curl https://get.acme.sh | sh
# Issue certificate (HTTP validation)
acme.sh --issue -d example.com -w /var/www/html
# Issue certificate (DNS validation with Cloudflare)
export CF_Key="your-cloudflare-api-key"
export CF_Email="your@email.com"
acme.sh --issue --dns dns_cf -d example.com -d *.example.com
# Install certificate
acme.sh --install-cert -d example.com \
--key-file /etc/nginx/ssl/example.com.key \
--fullchain-file /etc/nginx/ssl/example.com.crt \
--reloadcmd "systemctl reload nginx"
# Renew all certificates
acme.sh --renew-all
# Force renew
acme.sh --renew -d example.com --force
Certificate Management
Certificate Inspection
# View certificate details
openssl x509 -in cert.pem -text -noout
# View certificate dates
openssl x509 -in cert.pem -noout -dates
# View certificate subject
openssl x509 -in cert.pem -noout -subject
# View certificate issuer
openssl x509 -in cert.pem -noout -issuer
# View certificate fingerprint
openssl x509 -in cert.pem -noout -fingerprint -sha256
# Check certificate and key match
openssl x509 -noout -modulus -in cert.pem | openssl md5
openssl rsa -noout -modulus -in key.pem | openssl md5
# If md5 hashes match, cert and key are paired
# View certificate from server
openssl s_client -connect example.com:443 -showcerts
# Check certificate expiration
echo | openssl s_client -connect example.com:443 2>/dev/null | \
openssl x509 -noout -dates
Python Certificate Tools
from cryptography import x509
from cryptography.hazmat.primitives import serialization
import datetime
def inspect_certificate(cert_path):
with open(cert_path, 'rb') as f:
cert = x509.load_pem_x509_certificate(f.read())
print("Certificate Information:")
print(f"Subject: {cert.subject.rfc4514_string()}")
print(f"Issuer: {cert.issuer.rfc4514_string()}")
print(f"Serial Number: {cert.serial_number}")
print(f"Not Valid Before: {cert.not_valid_before}")
print(f"Not Valid After: {cert.not_valid_after}")
print(f"Signature Algorithm: {cert.signature_algorithm_oid._name}")
# Check if expired
now = datetime.datetime.utcnow()
days_until_expiry = (cert.not_valid_after - now).days
if now > cert.not_valid_after:
print("⚠ Certificate EXPIRED!")
elif days_until_expiry < 30:
print(f"⚠ Certificate expires soon ({days_until_expiry} days)")
else:
print(f"✓ Certificate valid ({days_until_expiry} days remaining)")
# Subject Alternative Names
try:
san_ext = cert.extensions.get_extension_for_oid(
x509.oid.ExtensionOID.SUBJECT_ALTERNATIVE_NAME
)
print(f"SANs: {', '.join([dns.value for dns in san_ext.value])}")
except x509.ExtensionNotFound:
print("No SANs found")
return cert
# Usage
cert = inspect_certificate('cert.pem')
Certificate Monitoring
import ssl
import socket
from datetime import datetime
def check_certificate_expiry(hostname, port=443):
"""Check SSL certificate expiration"""
context = ssl.create_default_context()
with socket.create_connection((hostname, port)) as sock:
with context.wrap_socket(sock, server_hostname=hostname) as ssock:
cert = ssock.getpeercert()
# Parse expiration date
expires = datetime.strptime(
    cert['notAfter'],
    '%b %d %H:%M:%S %Y %Z'  # e.g. 'Jun  1 12:00:00 2025 GMT'
)
days_remaining = (expires - datetime.now()).days
print(f"Certificate for {hostname}:")
print(f" Subject: {dict(x[0] for x in cert['subject'])['commonName']}")
print(f" Issuer: {dict(x[0] for x in cert['issuer'])['commonName']}")
print(f" Expires: {expires}")
print(f" Days remaining: {days_remaining}")
if days_remaining < 0:
print(" ⚠ EXPIRED!")
elif days_remaining < 30:
print(" ⚠ Expiring soon!")
else:
print(" ✓ Valid")
return days_remaining
# Check multiple sites
sites = ['google.com', 'github.com', 'example.com']
for site in sites:
try:
check_certificate_expiry(site)
print()
except Exception as e:
print(f"Error checking {site}: {e}\n")
Certificate Renewal Strategy
#!/bin/bash
# certificate-renewal.sh
# Check certificate expiration
check_expiry() {
    local domain=$1
    # -checkend 2592000 succeeds if the cert is valid for at least 30 more days
    if echo | openssl s_client -connect "$domain:443" 2>/dev/null | \
        openssl x509 -noout -checkend 2592000 >/dev/null; then
        echo "$domain: Certificate valid for at least 30 days"
        return 0
    else
        echo "$domain: Certificate expires within 30 days!"
        return 1
    fi
}
# Renew if needed
renew_certificate() {
local domain=$1
if ! check_expiry $domain; then
echo "Renewing certificate for $domain..."
certbot renew --cert-name $domain
if [ $? -eq 0 ]; then
echo "Certificate renewed successfully"
systemctl reload nginx
else
echo "Certificate renewal failed!"
# Send alert
fi
fi
}
# Check all domains
for domain in example.com api.example.com www.example.com; do
renew_certificate $domain
done
Certificate Revocation
Certificate Revocation Lists (CRL)
# Download CRL
wget http://crl.example.com/example.crl
# View CRL
openssl crl -in example.crl -text -noout
# Check if certificate is revoked
openssl verify -crl_check -CRLfile example.crl -CAfile ca.crt cert.pem
Online Certificate Status Protocol (OCSP)
# Get OCSP responder URL from certificate
openssl x509 -in cert.pem -noout -ocsp_uri
# Check certificate status via OCSP
openssl ocsp -issuer ca.crt -cert cert.pem \
-url http://ocsp.example.com \
-resp_text
# OCSP stapling check
openssl s_client -connect example.com:443 -status
Revoking Certificate
# Revoke with certbot
sudo certbot revoke --cert-path /etc/letsencrypt/live/example.com/cert.pem
# Revoke with reason
sudo certbot revoke --cert-path cert.pem --reason keycompromise
# Revoke with custom CA
openssl ca -config ca.conf -revoke cert.pem -keyfile ca.key -cert ca.crt
# Generate CRL
openssl ca -config ca.conf -gencrl -out crl.pem
Security Considerations
1. Key Size
RSA:
Minimum: 2048 bits
Recommended: 3072-4096 bits
ECDSA:
Recommended: P-256 (256-bit)
High security: P-384 (384-bit)
Ed25519:
Fixed: 256-bit (recommended for new deployments)
2. Certificate Validity Period
Modern best practices:
- Maximum: 398 days (13 months) - enforced by browsers
- Recommended: 90 days (Let's Encrypt default)
- Automated renewal: Essential for short validity
Historical:
- Before 2020: Up to 2-3 years
- 2020: 398 days maximum
- Trend: Shorter validity periods
3. Subject Alternative Names (SANs)
# Include all domain variants
subjectAltName = DNS:example.com,DNS:www.example.com,DNS:*.example.com
# Don't rely on Common Name (CN) - deprecated
# Always use SANs
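For illustration, here is a simplified sketch of how clients match a hostname against the SAN list — a wildcard covers exactly one DNS label. This is a teaching sketch, not a complete implementation (it ignores edge cases such as IP-address SANs and internationalized names):

```python
def matches_san(hostname, sans):
    """Naive SAN matching: exact names plus single-label wildcards."""
    for san in sans:
        if san.startswith('*.'):
            # '*.example.com' matches 'www.example.com' but NOT
            # 'example.com' or 'a.b.example.com'
            parts = hostname.split('.', 1)
            if len(parts) == 2 and parts[1] == san[2:]:
                return True
        elif hostname == san:
            return True
    return False

sans = ['example.com', 'www.example.com', '*.example.com']
assert matches_san('example.com', sans)
assert matches_san('api.example.com', sans)
assert not matches_san('a.b.example.com', sans)
```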
4. Certificate Pinning
import ssl
import hashlib
import socket
def verify_certificate_pinning(hostname, expected_fingerprints):
    """Verify certificate matches expected fingerprint"""
    context = ssl.create_default_context()
    with socket.create_connection((hostname, 443)) as sock:
        with context.wrap_socket(sock, server_hostname=hostname) as ssock:
            cert_der = ssock.getpeercert(binary_form=True)
            fingerprint = hashlib.sha256(cert_der).hexdigest()
            if fingerprint in expected_fingerprints:
                print("✓ Certificate pinning verified")
                return True
            else:
                print("✗ Certificate pinning failed!")
                print(f"  Expected: {expected_fingerprints}")
                print(f"  Got: {fingerprint}")
                return False

# Usage
expected_pins = [
    'a1b2c3d4e5f6...',  # Primary certificate
    '9a8b7c6d5e4f...',  # Backup certificate
]
verify_certificate_pinning('example.com', expected_pins)
Best Practices
1. Automate Certificate Management
✓ Use Let's Encrypt for free certificates
✓ Automate renewal (certbot, acme.sh)
✓ Monitor expiration dates
✓ Test renewal process regularly
✓ Use short validity periods (90 days)
2. Secure Private Keys
# Restrict permissions
chmod 600 private.key
# Use hardware security modules (HSM) for critical keys
# Use encrypted private keys
openssl rsa -aes256 -in private.key -out private_encrypted.key
# Never commit to version control
echo "*.key" >> .gitignore
echo "*.pem" >> .gitignore
3. Use Strong Cryptography
✓ RSA 2048-bit minimum (prefer 3072+)
✓ ECDSA P-256 or better
✓ SHA-256 or SHA-512 for signatures
✗ Avoid MD5, SHA-1
✗ Avoid RSA <2048 bits
4. Implement Certificate Transparency
# Check if certificate is in CT logs
curl https://crt.sh/?q=example.com
# Monitor for unauthorized certificates
# Use tools like certstream, certificate-transparency-go
Common Mistakes
1. Expired Certificates
Problem: Certificate expires unexpectedly
Solution: Automate monitoring and renewal
2. Missing Intermediate Certificates
Problem: Browser shows untrusted certificate
Solution: Include full chain (server + intermediate certs)
# Correct chain order
cat server.crt intermediate.crt > fullchain.pem
3. Certificate Name Mismatch
Problem: Certificate for wrong domain
Solution: Use proper SANs
# Include all domains
subjectAltName = DNS:example.com,DNS:www.example.com
4. Insecure Private Key
Problem: Private key readable by all users
Solution: Restrict permissions
chmod 600 private.key
chown root:root private.key
ELI10
Certificates are like ID cards for websites:
Without certificates:
- You visit "bank.com"
- How do you know it's really your bank?
- Attackers could pretend to be your bank!
With certificates:
1. Website has ID card (certificate)
   - Says: "I'm bank.com"
   - Has a special seal (signature)
2. Trusted Authority (CA like Let's Encrypt)
   - Like a government issuing passports
   - Checks: "Yes, you really own bank.com"
   - Adds their official seal
3. Your browser checks:
   - Is the ID card real? ✓
   - Is it expired? ✓
   - Does it match the website name? ✓
   - Is the seal from a trusted authority? ✓
4. Chain of Trust:
   Browser trusts → Root CA
   Root CA trusts → Intermediate CA
   Intermediate CA trusts → Website Certificate
   Therefore, Browser trusts → Website!
Let's Encrypt made it:
- Free (used to cost $$$)
- Automatic (renews itself)
- Easy (simple commands)
Real-world analogy:
- Certificate = Passport
- CA = Government passport office
- Browser = Border control checking passports
- Expiration date = Passport validity
- Renewal = Getting new passport before expiry
Further Resources
- Let's Encrypt Documentation
- X.509 Certificate Format (RFC 5280)
- ACME Protocol (RFC 8555)
- Certificate Transparency
- SSL Labs Server Test
- Certbot Documentation
- OpenSSL Cookbook
- Public Key Infrastructure (PKI) Guide
SSL/TLS (Secure Sockets Layer / Transport Layer Security)
Overview
TLS (Transport Layer Security) is a cryptographic protocol that provides secure communication over networks. SSL is the predecessor to TLS (now deprecated).
Key Features:
- Confidentiality: Encryption prevents eavesdropping
- Integrity: Detects message tampering
- Authentication: Verifies server (and optionally client) identity
SSL/TLS History
| Version | Year | Status | Notes |
|---|---|---|---|
| SSL 1.0 | - | Never released | Internal Netscape protocol |
| SSL 2.0 | 1995 | Deprecated | Serious security flaws |
| SSL 3.0 | 1996 | Deprecated | POODLE attack (2014) |
| TLS 1.0 | 1999 | Deprecated | Similar to SSL 3.0 |
| TLS 1.1 | 2006 | Deprecated | Minor improvements |
| TLS 1.2 | 2008 | Secure | Currently widely used |
| TLS 1.3 | 2018 | Secure | Modern, fastest, most secure |
TLS Handshake (TLS 1.2)
Full Handshake Process
Client Server
1. ClientHello -------->
- TLS version
- Cipher suites
- Random bytes
- Extensions
<-------- 2. ServerHello
- Chosen cipher suite
- Random bytes
- Session ID
3. Certificate
- Server certificate chain
4. ServerKeyExchange
- Key exchange parameters
5. ServerHelloDone
6. ClientKeyExchange -------->
- Pre-master secret (encrypted)
7. ChangeCipherSpec -------->
- Switch to encrypted communication
8. Finished -------->
- Verification message (encrypted)
<-------- 9. ChangeCipherSpec
<-------- 10. Finished
11. Encrypted Application Data <---> Encrypted Application Data
Detailed Steps
1. ClientHello
Client → Server:
TLS Version: 1.2
Cipher Suites:
- TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
- TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384
- TLS_RSA_WITH_AES_128_CBC_SHA256
Random: [28 bytes client random]
Session ID: [empty for new session]
Extensions:
- server_name: example.com
- supported_groups: P-256, P-384
- signature_algorithms: RSA-PSS-SHA256, ECDSA-SHA256
2. ServerHello
Server → Client:
TLS Version: 1.2
Cipher Suite: TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
Random: [28 bytes server random]
Session ID: [32 bytes for session resumption]
Extensions:
- renegotiation_info
- extended_master_secret
3. Certificate
Server → Client:
Certificate Chain:
1. Server certificate (example.com)
2. Intermediate CA certificate
[Root CA not sent - client has it]
4. ServerKeyExchange (for ECDHE)
Server → Client:
Curve: P-256
Public Key: [server's ephemeral ECDH public key]
Signature: [signed with server's private key]
5. ClientKeyExchange
Client → Server:
Pre-master Secret:
[Encrypted with server's public key (RSA) OR
Client's ephemeral ECDH public key (ECDHE)]
6. Master Secret Derivation
Both compute:
Master Secret = PRF(
pre-master secret,
"master secret",
ClientHello.random + ServerHello.random
)
Then derive:
- Client write MAC key
- Server write MAC key
- Client write encryption key
- Server write encryption key
- Client write IV
- Server write IV
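The derivation above can be sketched with the TLS 1.2 PRF, which is just HMAC-SHA256 applied in the P_hash expansion defined in RFC 5246. The secret and random values below are placeholder bytes, not a real handshake:

```python
import hmac
import hashlib

def p_sha256(secret: bytes, seed: bytes, length: int) -> bytes:
    """P_SHA256 expansion from RFC 5246, section 5."""
    out = b''
    a = seed  # A(0) = seed
    while len(out) < length:
        a = hmac.new(secret, a, hashlib.sha256).digest()  # A(i) = HMAC(secret, A(i-1))
        out += hmac.new(secret, a + seed, hashlib.sha256).digest()
    return out[:length]

def tls12_prf(secret: bytes, label: bytes, seed: bytes, length: int) -> bytes:
    return p_sha256(secret, label + seed, length)

# The master secret is defined to be exactly 48 bytes
client_random, server_random = b'C' * 32, b'S' * 32
master = tls12_prf(b'pre-master-secret', b'master secret',
                   client_random + server_random, 48)
assert len(master) == 48
```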
Visual TLS 1.2 Handshake
┌────────┐ ┌────────┐
│ Client │ │ Server │
└───┬────┘ └───┬────┘
│ │
│ ClientHello │
│ (ciphers, random, SNI) │
├───────────────────────────────────>│
│ │
│ ServerHello │
│ (chosen cipher, random)│
│<───────────────────────────────────┤
│ │
│ Certificate │
│ (server cert chain) │
│<───────────────────────────────────┤
│ │
│ ServerKeyExchange │
│ (DH params, signature) │
│<───────────────────────────────────┤
│ │
│ ServerHelloDone │
│<───────────────────────────────────┤
│ │
│ ClientKeyExchange │
│ (pre-master secret) │
├───────────────────────────────────>│
│ │
│ ChangeCipherSpec │
├───────────────────────────────────>│
│ │
│ Finished (encrypted) │
├───────────────────────────────────>│
│ │
│ ChangeCipherSpec │
│<───────────────────────────────────┤
│ │
│ Finished (encrypted) │
│<───────────────────────────────────┤
│ │
│ Application Data (encrypted) │
│<──────────────────────────────────>│
│ │
TLS 1.3 Handshake
TLS 1.3 is faster - only 1 round-trip (vs 2 in TLS 1.2):
Client Server
1. ClientHello -------->
- Key share (DH)
- Supported versions
- Cipher suites
<-------- 2. ServerHello
- Key share (DH)
- Chosen cipher
{Certificate}*
{CertificateVerify}*
{Finished}
[Application Data]
{Finished} -------->
[Application Data] <-------> [Application Data]
* Encrypted with handshake traffic keys
[] Encrypted with application traffic keys
Key Differences TLS 1.3 vs 1.2
| Feature | TLS 1.2 | TLS 1.3 |
|---|---|---|
| Round trips | 2-RTT | 1-RTT |
| 0-RTT mode | No | Yes (with risks) |
| Cipher suites | Many (weak ones) | Only 5 strong ones |
| Key exchange | RSA, DHE, ECDHE | Only (EC)DHE |
| Encryption | After handshake | Most of handshake encrypted |
| Performance | Slower | Faster |
| Security | Vulnerable configs | Secure by default |
TLS 1.3 Improvements
- Faster handshake (1-RTT instead of 2-RTT)
- 0-RTT mode (resume with no round trips)
- Removed weak crypto (RC4, MD5, SHA-1, RSA key exchange)
- Forward secrecy (mandatory ECDHE)
- Encrypted handshake (server certificate encrypted)
- Simplified cipher suites
Cipher Suites
Cipher Suite Format (TLS 1.2)
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
|   |     |   |    |       |   |
|   |     |   |    |       |   +-- PRF/MAC hash (SHA-256)
|   |     |   |    |       +------ AEAD mode (GCM)
|   |     |   |    +-------------- Encryption (AES-128)
|   |     |   +------------------- "WITH"
|   |     +----------------------- Authentication (RSA)
|   +----------------------------- Key exchange (ECDHE)
+--------------------------------- Protocol (TLS)
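As a rough illustration, a TLS 1.2 suite name can be split into these parts mechanically. This is a naive parser for the common `*_WITH_*` naming pattern only, not a general-purpose one:

```python
def parse_cipher_suite(name):
    """Split a TLS 1.2 cipher suite name into components (simplified)."""
    kex_auth, _, rest = name.partition('_WITH_')
    proto, kex, auth = kex_auth.split('_', 2)
    cipher, _, hash_alg = rest.rpartition('_')
    return {'protocol': proto, 'key_exchange': kex,
            'authentication': auth, 'cipher': cipher, 'hash': hash_alg}

info = parse_cipher_suite('TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256')
assert info['key_exchange'] == 'ECDHE'
assert info['cipher'] == 'AES_128_GCM'
assert info['hash'] == 'SHA256'
```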
Common Cipher Suites (TLS 1.2)
Strong & Recommended
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
Weak (Avoid)
TLS_RSA_WITH_RC4_128_SHA # RC4 broken
TLS_RSA_WITH_3DES_EDE_CBC_SHA # 3DES weak
TLS_RSA_WITH_AES_128_CBC_SHA # CBC mode, no forward secrecy
TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 # CBC mode
Cipher Suite Components
1. Key Exchange
RSA: No forward secrecy (deprecated)
DHE: Diffie-Hellman Ephemeral (slow)
ECDHE: Elliptic Curve DHE (fast, forward secrecy) ✓
2. Authentication
RSA: RSA certificate
ECDSA: Elliptic Curve certificate (smaller, faster)
DSA: Digital Signature Algorithm (obsolete)
3. Encryption
AES-128-GCM: Fast, secure, hardware accelerated ✓
AES-256-GCM: Higher security ✓
ChaCha20-Poly1305: Fast on mobile (no AES hardware) ✓
AES-CBC: Vulnerable to padding oracles (avoid)
3DES: Obsolete (avoid)
RC4: Broken (never use)
4. MAC (Message Authentication Code)
SHA-256: Secure ✓
SHA-384: Secure ✓
SHA-1: Weak (avoid)
MD5: Broken (never use)
Note: AEAD modes (GCM, ChaCha20-Poly1305) don't need separate MAC
TLS 1.3 Cipher Suites (Simplified)
TLS_AES_128_GCM_SHA256
TLS_AES_256_GCM_SHA384
TLS_CHACHA20_POLY1305_SHA256
TLS_AES_128_CCM_SHA256
TLS_AES_128_CCM_8_SHA256
Only 5 cipher suites! Key exchange and auth determined separately.
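You can check which suites your local OpenSSL build actually enables with Python's `ssl` module (`SSLContext.get_ciphers()`, available since CPython 3.6). The exact output depends on your OpenSSL version:

```python
import ssl

# List cipher suites enabled in the default client context
ctx = ssl.create_default_context()
names = [c['name'] for c in ctx.get_ciphers()]

# TLS 1.3 suites use the short names shown above (TLS_AES_..., TLS_CHACHA20_...)
tls13 = [n for n in names if n.startswith(('TLS_AES', 'TLS_CHACHA20'))]
print(f"{len(names)} suites enabled, TLS 1.3: {tls13}")
```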
Configuring TLS
Nginx Configuration
server {
    listen 443 ssl http2;
    server_name example.com;

    # Certificates
    ssl_certificate /etc/letsencrypt/live/example.com/fullchain.pem;
    ssl_certificate_key /etc/letsencrypt/live/example.com/privkey.pem;

    # TLS versions
    ssl_protocols TLSv1.2 TLSv1.3;

    # Cipher suites (TLS 1.2)
    ssl_ciphers 'ECDHE-RSA-AES128-GCM-SHA256:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-CHACHA20-POLY1305';
    ssl_prefer_server_ciphers on;

    # DH parameters (for DHE cipher suites)
    ssl_dhparam /etc/nginx/dhparam.pem;

    # OCSP Stapling
    ssl_stapling on;
    ssl_stapling_verify on;
    ssl_trusted_certificate /etc/letsencrypt/live/example.com/chain.pem;

    # Session tickets
    ssl_session_timeout 1d;
    ssl_session_cache shared:SSL:50m;
    ssl_session_tickets off;

    # HSTS (HTTP Strict Transport Security)
    add_header Strict-Transport-Security "max-age=31536000; includeSubDomains; preload" always;

    location / {
        root /var/www/html;
    }
}

# Redirect HTTP to HTTPS
server {
    listen 80;
    server_name example.com;
    return 301 https://$server_name$request_uri;
}
Apache Configuration
<VirtualHost *:443>
    ServerName example.com

    # Certificates
    SSLCertificateFile /etc/letsencrypt/live/example.com/cert.pem
    SSLCertificateKeyFile /etc/letsencrypt/live/example.com/privkey.pem
    SSLCertificateChainFile /etc/letsencrypt/live/example.com/chain.pem

    # TLS versions
    SSLProtocol -all +TLSv1.2 +TLSv1.3

    # Cipher suites
    SSLCipherSuite ECDHE-RSA-AES128-GCM-SHA256:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-CHACHA20-POLY1305
    SSLHonorCipherOrder on

    # OCSP Stapling
    SSLUseStapling on
    SSLStaplingCache "shmcb:logs/ssl_stapling(32768)"

    # HSTS
    Header always set Strict-Transport-Security "max-age=31536000; includeSubDomains; preload"

    DocumentRoot /var/www/html
</VirtualHost>

# Redirect HTTP to HTTPS
<VirtualHost *:80>
    ServerName example.com
    Redirect permanent / https://example.com/
</VirtualHost>
Python HTTPS Server
import http.server
import ssl
# Simple HTTPS server
server_address = ('0.0.0.0', 4443)
httpd = http.server.HTTPServer(server_address, http.server.SimpleHTTPRequestHandler)
# Create SSL context
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.load_cert_chain('cert.pem', 'key.pem')
# Optional: Configure cipher suites
context.set_ciphers('ECDHE+AESGCM:ECDHE+CHACHA20:DHE+AESGCM:DHE+CHACHA20:!aNULL:!MD5:!DSS')
# Wrap socket with TLS
httpd.socket = context.wrap_socket(httpd.socket, server_side=True)
print("Server running on https://localhost:4443")
httpd.serve_forever()
Python Client with TLS
import ssl
import socket
def https_request(hostname, path='/'):
    # Create SSL context (create_default_context verifies the certificate
    # and hostname by default)
    context = ssl.create_default_context()
    # Optional: trust a custom CA bundle instead of the system store
    # context.load_verify_locations('ca-bundle.crt')
    # Connect
    with socket.create_connection((hostname, 443)) as sock:
        with context.wrap_socket(sock, server_hostname=hostname) as ssock:
            # Send HTTP request
            request = f"GET {path} HTTP/1.1\r\nHost: {hostname}\r\nConnection: close\r\n\r\n"
            ssock.sendall(request.encode())
            # Receive response
            response = b''
            while True:
                data = ssock.recv(4096)
                if not data:
                    break
                response += data
    return response.decode()

# Usage
response = https_request('example.com', '/')
print(response)
Using Python Requests Library
import requests
# Basic HTTPS request (verifies certificates by default)
response = requests.get('https://example.com')
# Disable certificate verification (not recommended!)
response = requests.get('https://example.com', verify=False)
# Use custom CA bundle
response = requests.get('https://example.com', verify='/path/to/ca-bundle.crt')
# Client certificate authentication
response = requests.get('https://example.com',
                        cert=('client.crt', 'client.key'))
# Specify TLS version
import ssl
from requests.adapters import HTTPAdapter
from urllib3.util.ssl_ import create_urllib3_context
class TLSAdapter(HTTPAdapter):
    def init_poolmanager(self, *args, **kwargs):
        ctx = create_urllib3_context()
        ctx.minimum_version = ssl.TLSVersion.TLSv1_2
        ctx.maximum_version = ssl.TLSVersion.TLSv1_3
        kwargs['ssl_context'] = ctx
        return super().init_poolmanager(*args, **kwargs)
session = requests.Session()
session.mount('https://', TLSAdapter())
response = session.get('https://example.com')
Testing TLS Configuration
OpenSSL Command-Line Tests
# Connect to server and show TLS info
openssl s_client -connect example.com:443 -servername example.com
# Test specific TLS version
openssl s_client -connect example.com:443 -tls1_2
openssl s_client -connect example.com:443 -tls1_3
# Test if old protocols are disabled
openssl s_client -connect example.com:443 -ssl3 # Should fail
openssl s_client -connect example.com:443 -tls1 # Should fail
openssl s_client -connect example.com:443 -tls1_1 # Should fail
# Test specific cipher suite
openssl s_client -connect example.com:443 -cipher 'ECDHE-RSA-AES128-GCM-SHA256'
# Show certificate chain
openssl s_client -connect example.com:443 -showcerts
# Check OCSP stapling
openssl s_client -connect example.com:443 -status
# Check certificate expiration
echo | openssl s_client -connect example.com:443 2>/dev/null | openssl x509 -noout -dates
# Full connection info
openssl s_client -connect example.com:443 -servername example.com </dev/null | grep -E 'Protocol|Cipher'
Testing Tools
nmap
# Scan TLS versions
nmap --script ssl-enum-ciphers -p 443 example.com
# Check for vulnerabilities
nmap --script ssl-* -p 443 example.com
testssl.sh
# Install
git clone https://github.com/drwetter/testssl.sh.git
cd testssl.sh
# Run comprehensive test
./testssl.sh https://example.com
# Test specific features
./testssl.sh --protocols https://example.com
./testssl.sh --ciphers https://example.com
./testssl.sh --vulnerabilities https://example.com
SSL Labs
# Online tool (web interface)
# https://www.ssllabs.com/ssltest/
# API
curl "https://api.ssllabs.com/api/v3/analyze?host=example.com"
Python TLS Testing
import ssl
import socket
def test_tls_version(hostname, port=443):
    """Test TLS versions supported by server"""
    versions = {
        'TLS 1.0': ssl.TLSVersion.TLSv1,
        'TLS 1.1': ssl.TLSVersion.TLSv1_1,
        'TLS 1.2': ssl.TLSVersion.TLSv1_2,
        'TLS 1.3': ssl.TLSVersion.TLSv1_3,
    }
    for version_name, version in versions.items():
        try:
            # Pin the connection to exactly one protocol version
            context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
            context.check_hostname = False      # version probe only
            context.verify_mode = ssl.CERT_NONE
            context.minimum_version = version
            context.maximum_version = version
            with socket.create_connection((hostname, port), timeout=5) as sock:
                with context.wrap_socket(sock, server_hostname=hostname) as ssock:
                    print(f"✓ {version_name}: Supported (cipher: {ssock.cipher()[0]})")
        except Exception:
            print(f"✗ {version_name}: Not supported")
def get_certificate_info(hostname, port=443):
    """Get server certificate information"""
    context = ssl.create_default_context()
    with socket.create_connection((hostname, port)) as sock:
        with context.wrap_socket(sock, server_hostname=hostname) as ssock:
            cert = ssock.getpeercert()
            print(f"Subject: {dict(x[0] for x in cert['subject'])}")
            print(f"Issuer: {dict(x[0] for x in cert['issuer'])}")
            print(f"Version: {cert['version']}")
            print(f"Serial: {cert['serialNumber']}")
            print(f"Not Before: {cert['notBefore']}")
            print(f"Not After: {cert['notAfter']}")
            print(f"SANs: {', '.join(x[1] for x in cert.get('subjectAltName', []))}")
            print(f"TLS Version: {ssock.version()}")
            print(f"Cipher: {ssock.cipher()}")
# Usage
test_tls_version('example.com')
print()
get_certificate_info('example.com')
Common TLS Vulnerabilities
1. POODLE (Padding Oracle On Downgraded Legacy Encryption)
Attack: Forces downgrade to SSL 3.0, exploits CBC padding
Mitigation:
# Disable SSL 3.0
ssl_protocols TLSv1.2 TLSv1.3;
2. BEAST (Browser Exploit Against SSL/TLS)
Attack: Exploits CBC mode in TLS 1.0
Mitigation:
# Disable TLS 1.0, use modern cipher suites
ssl_protocols TLSv1.2 TLSv1.3;
ssl_ciphers 'ECDHE-RSA-AES128-GCM-SHA256:...';
3. CRIME (Compression Ratio Info-leak Made Easy)
Attack: Exploits TLS compression
Mitigation:
# TLS compression is disabled by default in modern OpenSSL builds;
# nginx has no directive to enable it — keep OpenSSL up to date
4. Heartbleed
Attack: Buffer over-read in OpenSSL heartbeat extension
Mitigation:
# Update OpenSSL
sudo apt-get update
sudo apt-get upgrade openssl
# Check version (must be 1.0.1g or newer)
openssl version
5. Logjam
Attack: Weakness in DHE key exchange with small primes
Mitigation:
# Generate strong DH parameters
openssl dhparam -out /etc/nginx/dhparam.pem 2048
# Configure nginx
ssl_dhparam /etc/nginx/dhparam.pem;
6. FREAK (Factoring RSA Export Keys)
Attack: Forces use of weak export-grade encryption
Mitigation:
# Disable export ciphers
ssl_ciphers 'ECDHE-RSA-AES128-GCM-SHA256:...'; # No EXPORT
7. DROWN (Decrypting RSA with Obsolete and Weakened eNcryption)
Attack: Exploits SSLv2 to break TLS
Mitigation:
# Ensure SSLv2 is disabled everywhere
# Check with:
nmap --script ssl-enum-ciphers -p 443 example.com
Checking for Vulnerabilities
# Using testssl.sh
./testssl.sh --vulnerabilities https://example.com
# Using nmap
nmap --script ssl-heartbleed,ssl-poodle,ssl-dh-params -p 443 example.com
Best Practices
1. Protocol Configuration
# ✓ GOOD - Only modern protocols
ssl_protocols TLSv1.2 TLSv1.3;
# ✗ BAD - Includes old protocols
ssl_protocols TLSv1 TLSv1.1 TLSv1.2 TLSv1.3;
2. Cipher Suite Selection
# ✓ GOOD - Strong, forward-secret ciphers
ssl_ciphers 'ECDHE-RSA-AES128-GCM-SHA256:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-CHACHA20-POLY1305';
# ✗ BAD - Includes weak ciphers
ssl_ciphers 'ALL:!aNULL:!MD5';
3. Certificate Management
# ✓ Use certificates from trusted CA (Let's Encrypt)
# ✓ Automate renewal
# ✓ Monitor expiration
# ✓ Include full certificate chain
# ✗ Don't use self-signed in production
# ✗ Don't let certificates expire
4. HSTS (HTTP Strict Transport Security)
# Enforce HTTPS for all subdomains
add_header Strict-Transport-Security "max-age=31536000; includeSubDomains; preload" always;
5. OCSP Stapling
# Enable OCSP stapling for faster certificate validation
ssl_stapling on;
ssl_stapling_verify on;
ssl_trusted_certificate /path/to/chain.pem;
6. Session Management
# Session resumption (performance)
ssl_session_timeout 1d;
ssl_session_cache shared:SSL:50m;
# Disable session tickets (forward secrecy)
ssl_session_tickets off;
7. Perfect Forward Secrecy
Use ECDHE or DHE key exchange:
- ECDHE: Fast, modern
- DHE: Slower, but compatible
Avoid RSA key exchange (no forward secrecy)
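The core idea behind (EC)DHE — both endpoints compute the same secret without it ever crossing the wire — can be shown with a toy finite-field Diffie-Hellman using deliberately tiny numbers. Real deployments use 2048-bit+ primes or elliptic curves, and the ephemeral (per-session) keys are what provide forward secrecy:

```python
# Public parameters: a small prime p and generator g (toy values only!)
p, g = 23, 5

# Each side picks an ephemeral private key and publishes g^x mod p
client_priv, server_priv = 6, 15
client_pub = pow(g, client_priv, p)
server_pub = pow(g, server_priv, p)

# Each side combines its own private key with the peer's public value
client_secret = pow(server_pub, client_priv, p)
server_secret = pow(client_pub, server_priv, p)
assert client_secret == server_secret  # same shared secret, never transmitted
```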
8. Regular Updates
# Keep OpenSSL updated
sudo apt-get update
sudo apt-get upgrade openssl libssl-dev
# Keep web server updated
sudo apt-get upgrade nginx # or apache2
9. Monitoring and Testing
# Regular security scans
./testssl.sh https://example.com
# Monitor certificate expiration
curl https://crt.sh/?q=example.com
# Check SSL Labs rating
curl "https://api.ssllabs.com/api/v3/analyze?host=example.com"
Security Checklist
Certificate:
[✓] Valid and not expired
[✓] From trusted CA
[✓] Matches domain name
[✓] Includes full chain
[✓] Strong key (RSA 2048+ or ECDSA P-256+)
Protocol:
[✓] TLS 1.2 minimum
[✓] TLS 1.3 enabled
[✓] SSL 3.0 disabled
[✓] TLS 1.0/1.1 disabled
Cipher Suites:
[✓] Only strong ciphers
[✓] Forward secrecy (ECDHE)
[✓] AEAD modes (GCM, ChaCha20-Poly1305)
[✓] No weak ciphers (RC4, 3DES, etc.)
Headers:
[✓] HSTS enabled
[✓] Secure cookie flags
Features:
[✓] OCSP stapling enabled
[✓] Session tickets disabled
[✓] HTTP → HTTPS redirect
Vulnerabilities:
[✓] Not vulnerable to POODLE
[✓] Not vulnerable to BEAST
[✓] Not vulnerable to Heartbleed
[✓] Not vulnerable to Logjam
[✓] Not vulnerable to FREAK
[✓] Not vulnerable to DROWN
Common Mistakes
1. Mixed Content
<!-- BAD - Loading HTTP resource on HTTPS page -->
<script src="http://example.com/script.js"></script>
<!-- GOOD - Use HTTPS -->
<script src="https://example.com/script.js"></script>
<!-- NOTE - Protocol-relative URLs also work, but are now discouraged; prefer explicit https:// -->
<script src="//example.com/script.js"></script>
2. Weak Cipher Configuration
# BAD - Allows weak ciphers
ssl_ciphers 'ALL:!aNULL:!MD5';
# GOOD - Only strong ciphers
ssl_ciphers 'ECDHE-RSA-AES128-GCM-SHA256:ECDHE-RSA-AES256-GCM-SHA384';
3. Missing Certificate Chain
# BAD - Only server certificate
ssl_certificate /path/to/cert.pem;
# GOOD - Full chain (server + intermediate)
ssl_certificate /path/to/fullchain.pem;
4. Expired Certificates
# Check expiration regularly
echo | openssl s_client -connect example.com:443 2>/dev/null | openssl x509 -noout -dates
# Automate renewal
certbot renew
5. Not Redirecting HTTP to HTTPS
# Missing HTTP → HTTPS redirect leaves users vulnerable
# GOOD - Redirect all HTTP to HTTPS
server {
listen 80;
server_name example.com;
return 301 https://$server_name$request_uri;
}
ELI10
TLS is like a secure tunnel for internet communication:
Without TLS (HTTP):
You: "My password is abc123"
↓ (anyone can read this!)
Server: "OK, logged in"
Bad guys can see everything!
With TLS (HTTPS):
Step 1: Build a secure tunnel
You: "Let's talk securely!"
Server: "Here's my ID card" (certificate)
You: "OK, I trust you"
Both: [Create secret code together]
Step 2: Talk through tunnel
You: "xf9#k2@..." (encrypted password)
↓ (looks like gibberish to bad guys!)
Server: "p8#nz..." (encrypted response)
The Handshake (making friends):
- You: "Hi! I speak TLS 1.2 and TLS 1.3"
- Server: "Great! Let's use TLS 1.3. Here's my ID card"
- You: "ID looks good! Here's a secret number"
- Server: "Got it! Here's my secret number"
- Both: "Let's mix our secrets to make a key!"
- Both: "Tunnel ready! Let's talk!"
Why it's secure:
- Encryption: Messages look like random gibberish
- Authentication: Server proves it's really who it claims to be
- Integrity: Detect if someone changes messages
TLS 1.3 is better:
- Faster (1 round trip instead of 2)
- More secure (removed old, weak options)
- Simpler (fewer choices = fewer mistakes)
Real-world analogy:
- HTTP = Postcard (anyone can read it)
- HTTPS = Sealed letter with signature (secure and verified)
Further Resources
- TLS 1.3 RFC 8446
- TLS 1.2 RFC 5246
- Mozilla SSL Configuration Generator
- SSL Labs Server Test
- testssl.sh GitHub
- OWASP TLS Cheat Sheet
- High Performance Browser Networking (Book)
- TLS Illustrated
- Cloudflare TLS 1.3 Guide
HMAC (Hash-based Message Authentication Code)
Overview
HMAC is a mechanism for message authentication using cryptographic hash functions. It provides both data integrity (message hasn't been altered) and authentication (message came from someone with the secret key).
HMAC Construction
Formula
HMAC(K, m) = H((K' ⊕ opad) || H((K' ⊕ ipad) || m))
Where:
- K = secret key
- m = message
- H = cryptographic hash function (SHA-256, SHA-512, etc.)
- K' = key derived from K (padded/hashed to block size)
- ⊕ = XOR operation
- || = concatenation
- opad = outer padding (0x5c repeated)
- ipad = inner padding (0x36 repeated)
Simplified Steps
1. If key is longer than block size, hash it
2. If key is shorter than block size, pad with zeros
3. XOR key with inner padding (ipad)
4. Append message to result
5. Hash the result (inner hash)
6. XOR key with outer padding (opad)
7. Append inner hash to result
8. Hash the result (outer hash) = HMAC
Visual Representation
Secret Key
|
+------+------+
| |
XOR ipad XOR opad
| |
+ Message |
| |
Hash (inner) |
| |
+-------------+
|
Hash (outer)
|
HMAC
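The construction above can be written out in a few lines and cross-checked against Python's standard hmac module (SHA-256 has a 64-byte block size):

```python
import hashlib
import hmac

def hmac_sha256(key: bytes, message: bytes) -> bytes:
    """Textbook HMAC-SHA256: H((K' ^ opad) || H((K' ^ ipad) || m))."""
    block_size = 64                       # SHA-256 block size in bytes
    if len(key) > block_size:             # step 1: hash over-long keys
        key = hashlib.sha256(key).digest()
    key = key.ljust(block_size, b'\x00')  # step 2: zero-pad short keys
    ipad = bytes(b ^ 0x36 for b in key)   # steps 3-5: inner hash
    opad = bytes(b ^ 0x5c for b in key)   # steps 6-8: outer hash
    inner = hashlib.sha256(ipad + message).digest()
    return hashlib.sha256(opad + inner).digest()

key, msg = b"secret-key", b"Important message"
assert hmac_sha256(key, msg) == hmac.new(key, msg, hashlib.sha256).digest()
```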
Why HMAC Instead of Hash(Key + Message)?
Vulnerable Approaches
# VULNERABLE 1: Simple concatenation
tag = sha256(key + message)
# Vulnerable to length extension attacks!
# VULNERABLE 2: Wrong order
tag = sha256(message + key)
# Attacker can append data!
# SECURE: Use HMAC
tag = hmac.new(key, message, sha256).digest()
Length Extension Attack Example
# With SHA-256 concatenation (VULNERABLE)
original = sha256(key + message)
# Attacker can compute: sha256(key + message + attacker_data)
# WITHOUT knowing the key!
# With HMAC (SECURE)
original = hmac(key, message)
# Attacker CANNOT extend the message without knowing the key
Using HMAC
Python Examples
Basic HMAC
import hmac
import hashlib
# Create HMAC
key = b"secret-key-12345"
message = b"Important message"
# HMAC-SHA256
mac = hmac.new(key, message, hashlib.sha256)
tag = mac.hexdigest()
print(f"HMAC-SHA256: {tag}")
# HMAC-SHA512
mac = hmac.new(key, message, hashlib.sha512)
tag = mac.hexdigest()
print(f"HMAC-SHA512: {tag}")
# Digest as bytes
tag_bytes = hmac.new(key, message, hashlib.sha256).digest()
print(f"HMAC (bytes): {tag_bytes}")
Verify HMAC
import hmac
import hashlib
def create_hmac(key, message):
    return hmac.new(key, message, hashlib.sha256).digest()

def verify_hmac(key, message, received_tag):
    expected_tag = hmac.new(key, message, hashlib.sha256).digest()
    # Use constant-time comparison to prevent timing attacks
    return hmac.compare_digest(expected_tag, received_tag)

# Usage
key = b"secret-key"
message = b"Transfer $100 to Alice"

# Create tag
tag = create_hmac(key, message)
print(f"Tag: {tag.hex()}")

# Verify tag (correct)
if verify_hmac(key, message, tag):
    print("Message is authentic!")

# Verify tag (tampered message)
tampered = b"Transfer $999 to Alice"
if not verify_hmac(key, tampered, tag):
    print("Message has been tampered with!")
Incremental HMAC
import hmac
import hashlib
# For large messages
mac = hmac.new(b"secret-key", digestmod=hashlib.sha256)
# Update incrementally
mac.update(b"Part 1 of message ")
mac.update(b"Part 2 of message ")
mac.update(b"Part 3 of message")
tag = mac.hexdigest()
print(f"Incremental HMAC: {tag}")
# Equivalent to
mac_full = hmac.new(b"secret-key",
                    b"Part 1 of message Part 2 of message Part 3 of message",
                    hashlib.sha256)
print(f"Full HMAC: {mac_full.hexdigest()}")
OpenSSL/Bash Examples
# Generate HMAC-SHA256
echo -n "Important message" | openssl dgst -sha256 -hmac "secret-key"
# HMAC-SHA512
echo -n "Important message" | openssl dgst -sha512 -hmac "secret-key"
# HMAC of a file
openssl dgst -sha256 -hmac "secret-key" document.pdf
# Output in different formats
echo -n "message" | openssl dgst -sha256 -hmac "key" -hex
echo -n "message" | openssl dgst -sha256 -hmac "key" -binary | base64
JavaScript Example
const crypto = require('crypto');
// Create HMAC
const key = 'secret-key-12345';
const message = 'Important message';
const hmac = crypto.createHmac('sha256', key);
hmac.update(message);
const tag = hmac.digest('hex');
console.log(`HMAC-SHA256: ${tag}`);
// Verify HMAC
function verifyHMAC(key, message, receivedTag) {
    const expectedTag = crypto.createHmac('sha256', key)
        .update(message)
        .digest('hex');
    // Constant-time comparison
    return crypto.timingSafeEqual(
        Buffer.from(expectedTag, 'hex'),
        Buffer.from(receivedTag, 'hex')
    );
}
Message Authentication
Sending Authenticated Messages
import hmac
import hashlib
import json

class AuthenticatedMessage:
    def __init__(self, shared_key):
        self.key = shared_key

    def send(self, message):
        # Create HMAC tag
        tag = hmac.new(self.key, message.encode(), hashlib.sha256).hexdigest()
        # Package message with tag
        package = {
            'message': message,
            'hmac': tag
        }
        return json.dumps(package)

    def receive(self, package_json):
        # Unpack message
        package = json.loads(package_json)
        message = package['message']
        received_tag = package['hmac']
        # Verify HMAC
        expected_tag = hmac.new(self.key, message.encode(), hashlib.sha256).hexdigest()
        if hmac.compare_digest(expected_tag, received_tag):
            return message, True
        else:
            return None, False

# Usage
shared_key = b"shared-secret-key-between-alice-and-bob"

# Alice sends message
alice = AuthenticatedMessage(shared_key)
package = alice.send("Transfer $100 to Bob")
print(f"Sent: {package}")

# Bob receives message
bob = AuthenticatedMessage(shared_key)
message, is_authentic = bob.receive(package)
if is_authentic:
    print(f"Authentic message: {message}")
else:
    print("Warning: Message tampered!")

# Attacker tries to tamper
tampered_package = package.replace("$100", "$999")
message, is_authentic = bob.receive(tampered_package)
print(f"Tampered authentic: {is_authentic}")  # False
Integrity Verification
File Integrity with HMAC
import hmac
import hashlib
import os
class FileIntegrityChecker:
    def __init__(self, key):
        self.key = key

    def compute_file_hmac(self, filepath):
        mac = hmac.new(self.key, digestmod=hashlib.sha256)
        with open(filepath, 'rb') as f:
            while chunk := f.read(8192):
                mac.update(chunk)
        return mac.hexdigest()

    def create_manifest(self, files):
        manifest = {}
        for filepath in files:
            manifest[filepath] = self.compute_file_hmac(filepath)
        return manifest

    def verify_files(self, manifest):
        results = {}
        for filepath, expected_hmac in manifest.items():
            if not os.path.exists(filepath):
                results[filepath] = "MISSING"
            else:
                actual_hmac = self.compute_file_hmac(filepath)
                if hmac.compare_digest(expected_hmac, actual_hmac):
                    results[filepath] = "OK"
                else:
                    results[filepath] = "MODIFIED"
        return results

# Usage
checker = FileIntegrityChecker(b"integrity-check-key")

# Create manifest
files = ['config.json', 'app.py', 'data.db']
manifest = checker.create_manifest(files)
print("Manifest created:", manifest)

# Later, verify files
results = checker.verify_files(manifest)
for file, status in results.items():
    print(f"{file}: {status}")
API Authentication
API Request Signing
import hmac
import hashlib
import time
import requests
from urllib.parse import urlencode

class APIClient:
    def __init__(self, api_key, api_secret):
        self.api_key = api_key
        self.api_secret = api_secret.encode()

    def generate_signature(self, method, path, params):
        # Create string to sign
        timestamp = str(int(time.time()))
        params['timestamp'] = timestamp
        params['api_key'] = self.api_key
        # Sort parameters
        sorted_params = sorted(params.items())
        query_string = urlencode(sorted_params)
        # String to sign: METHOD + PATH + QUERY_STRING
        message = f"{method}{path}{query_string}"
        # Generate HMAC signature
        signature = hmac.new(
            self.api_secret,
            message.encode(),
            hashlib.sha256
        ).hexdigest()
        return signature, timestamp

    def make_request(self, method, path, params=None):
        if params is None:
            params = {}
        # Generate signature
        signature, timestamp = self.generate_signature(method, path, params)
        # Add authentication headers
        headers = {
            'X-API-Key': self.api_key,
            'X-API-Signature': signature,
            'X-API-Timestamp': timestamp
        }
        # Make request
        url = f"https://api.example.com{path}"
        response = requests.request(method, url, params=params, headers=headers)
        return response

# Server-side verification
class APIServer:
    def __init__(self):
        # In practice, look up secret from database based on API key
        self.api_secrets = {
            'key123': b'secret123'
        }

    def verify_signature(self, api_key, signature, timestamp, method, path, params):
        # Check timestamp (prevent replay attacks)
        current_time = int(time.time())
        request_time = int(timestamp)
        if abs(current_time - request_time) > 300:  # 5 minutes
            return False, "Request expired"
        # Get API secret
        if api_key not in self.api_secrets:
            return False, "Invalid API key"
        api_secret = self.api_secrets[api_key]
        # Reconstruct signed message
        params['timestamp'] = timestamp
        params['api_key'] = api_key
        sorted_params = sorted(params.items())
        query_string = urlencode(sorted_params)
        message = f"{method}{path}{query_string}"
        # Compute expected signature
        expected_signature = hmac.new(
            api_secret,
            message.encode(),
            hashlib.sha256
        ).hexdigest()
        # Compare signatures (constant time)
        if hmac.compare_digest(expected_signature, signature):
            return True, "Valid"
        else:
            return False, "Invalid signature"

# Usage
client = APIClient('key123', 'secret123')
response = client.make_request('GET', '/api/users', {'limit': 10})
REST API with HMAC Authentication
from flask import Flask, request, jsonify
import hmac
import hashlib

app = Flask(__name__)

API_SECRETS = {
    'client1': b'secret1',
    'client2': b'secret2'
}

def verify_hmac_signature():
    api_key = request.headers.get('X-API-Key')
    signature = request.headers.get('X-Signature')
    if not api_key or not signature:
        return False
    if api_key not in API_SECRETS:
        return False
    # Reconstruct signed data
    # Method + Path + Body (for POST/PUT)
    data = request.method + request.path
    if request.data:
        data += request.data.decode()
    # Compute expected signature
    expected = hmac.new(
        API_SECRETS[api_key],
        data.encode(),
        hashlib.sha256
    ).hexdigest()
    return hmac.compare_digest(expected, signature)

@app.route('/api/data', methods=['POST'])
def post_data():
    if not verify_hmac_signature():
        return jsonify({'error': 'Unauthorized'}), 401
    # Process request
    data = request.json
    return jsonify({'status': 'success', 'data': data})
# Client request example
import requests
import hmac
import hashlib
import json

api_key = 'client1'
api_secret = b'secret1'
url = 'http://localhost:5000/api/data'
payload = {'key': 'value'}

# Create signature (must cover the exact bytes sent as the request body)
data = 'POST' + '/api/data' + json.dumps(payload)
signature = hmac.new(api_secret, data.encode(), hashlib.sha256).hexdigest()

headers = {
    'X-API-Key': api_key,
    'X-Signature': signature,
    'Content-Type': 'application/json'
}
response = requests.post(url, json=payload, headers=headers)
JWT (JSON Web Tokens)
JWTs use HMAC (symmetric) or RSA/ECDSA (asymmetric) signatures for verification.
JWT Structure
eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c
The three dot-separated segments are the Header, the Payload, and the Signature.
JWT with HMAC
import hmac
import hashlib
import json
import base64
import time

class JWT:
    def __init__(self, secret):
        self.secret = secret.encode()

    def base64url_encode(self, data):
        return base64.urlsafe_b64encode(data).rstrip(b'=').decode()

    def base64url_decode(self, data):
        # No padding is needed when the length is already a multiple of 4
        padding = (4 - len(data) % 4) % 4
        data += '=' * padding
        return base64.urlsafe_b64decode(data)

    def create_token(self, payload):
        # Header
        header = {
            'alg': 'HS256',
            'typ': 'JWT'
        }
        # Encode header and payload
        header_encoded = self.base64url_encode(json.dumps(header).encode())
        payload_encoded = self.base64url_encode(json.dumps(payload).encode())
        # Create signature
        message = f"{header_encoded}.{payload_encoded}".encode()
        signature = hmac.new(self.secret, message, hashlib.sha256).digest()
        signature_encoded = self.base64url_encode(signature)
        # Combine
        token = f"{header_encoded}.{payload_encoded}.{signature_encoded}"
        return token

    def verify_token(self, token):
        try:
            parts = token.split('.')
            if len(parts) != 3:
                return None, False
            header_encoded, payload_encoded, signature_encoded = parts
            # Verify signature
            message = f"{header_encoded}.{payload_encoded}".encode()
            expected_signature = hmac.new(self.secret, message, hashlib.sha256).digest()
            received_signature = self.base64url_decode(signature_encoded)
            if not hmac.compare_digest(expected_signature, received_signature):
                return None, False
            # Decode payload
            payload = json.loads(self.base64url_decode(payload_encoded))
            return payload, True
        except Exception:
            return None, False

# Usage
jwt = JWT('my-secret-key')

# Create token
payload = {
    'user_id': 12345,
    'username': 'john_doe',
    'exp': int(time.time()) + 3600  # Expires in 1 hour
}
token = jwt.create_token(payload)
print(f"JWT: {token}")

# Verify token
payload, is_valid = jwt.verify_token(token)
if is_valid:
    print(f"Valid token! User: {payload['username']}")
else:
    print("Invalid token!")

# Using PyJWT library (recommended)
import jwt as pyjwt

# Create token
token = pyjwt.encode(payload, 'my-secret-key', algorithm='HS256')

# Verify token
try:
    decoded = pyjwt.decode(token, 'my-secret-key', algorithms=['HS256'])
    print(f"Valid! Payload: {decoded}")
except pyjwt.InvalidTokenError:
    print("Invalid token!")
HMAC vs Other MACs
Comparison
| Feature | HMAC | CBC-MAC | GMAC | Poly1305 |
|---|---|---|---|---|
| Based on | Hash function | Block cipher | Block cipher | Universal hash |
| Performance | Moderate | Slow | Fast | Very fast |
| Key reuse | Safe | Dangerous | Safe | One-time key |
| Standardized | Yes (RFC 2104) | Yes | Yes (GCM) | Yes (ChaCha20) |
| Use case | General purpose | Legacy | AEAD | Modern crypto |
HMAC-SHA256 vs HMAC-SHA512
import hmac
import hashlib
import time

message = b"x" * 1000000  # 1 MB
key = b"secret-key"

# HMAC-SHA256
start = time.time()
for _ in range(100):
    hmac.new(key, message, hashlib.sha256).digest()
print(f"HMAC-SHA256: {time.time() - start:.3f}s")

# HMAC-SHA512
start = time.time()
for _ in range(100):
    hmac.new(key, message, hashlib.sha512).digest()
print(f"HMAC-SHA512: {time.time() - start:.3f}s")

# Output sizes
print(f"SHA256 output: {len(hmac.new(key, b'test', hashlib.sha256).digest())} bytes")
print(f"SHA512 output: {len(hmac.new(key, b'test', hashlib.sha512).digest())} bytes")
Security Considerations
1. Key Length
# Minimum key length = hash output size
# SHA-256: minimum 32 bytes
# SHA-512: minimum 64 bytes
# GOOD
key = os.urandom(32) # 256 bits for HMAC-SHA256
# BAD - too short
key = b"secret" # Only 48 bits!
# Better - derive from password
from hashlib import pbkdf2_hmac
key = pbkdf2_hmac('sha256', b'user-password', b'salt', 100000)
2. Constant-Time Comparison
# VULNERABLE - timing attack
if computed_hmac == received_hmac:
    return True

# SECURE - constant time comparison
import hmac
if hmac.compare_digest(computed_hmac, received_hmac):
    return True
3. Prevent Replay Attacks
import time

def verify_request(hmac_tag, timestamp, max_age=300):
    # Verify HMAC first
    if not verify_hmac(hmac_tag):
        return False
    # Check timestamp (prevent replays)
    current_time = int(time.time())
    request_time = int(timestamp)
    if abs(current_time - request_time) > max_age:
        return False  # Request too old
    # Optional: Track used nonces to prevent replay
    # if nonce in used_nonces:
    #     return False
    return True
4. Use Separate Keys
# BAD - same key for different purposes
encryption_key = b"shared-key"
hmac_key = b"shared-key"
# GOOD - derive separate keys
from hashlib import sha256
master_key = b"master-secret-key"
encryption_key = sha256(master_key + b"encryption").digest()
hmac_key = sha256(master_key + b"authentication").digest()
# BETTER - use HKDF (HMAC-based Key Derivation)
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
master_key = b"master-secret-key"
hkdf = HKDF(
    algorithm=hashes.SHA256(),
    length=32,
    salt=None,
    info=b'encryption',
)
encryption_key = hkdf.derive(master_key)

hkdf = HKDF(
    algorithm=hashes.SHA256(),
    length=32,
    salt=None,
    info=b'authentication',
)
hmac_key = hkdf.derive(master_key)
5. Truncation
# Full HMAC (recommended)
mac = hmac.new(key, message, hashlib.sha256).digest() # 32 bytes
# Truncated HMAC (if needed)
mac_truncated = hmac.new(key, message, hashlib.sha256).digest()[:16] # 16 bytes
# Minimum recommended: 128 bits (16 bytes)
# Never go below 80 bits (10 bytes)
Best Practices
1. Always Use HMAC for Message Authentication
# ✓ Use HMAC
tag = hmac.new(key, message, hashlib.sha256).digest()
# ✗ Don't use simple hash
tag = hashlib.sha256(key + message).digest() # Vulnerable!
2. Choose Appropriate Hash Function
# Modern: SHA-256 or SHA-512
hmac.new(key, message, hashlib.sha256)
# Avoid: MD5 or SHA-1
hmac.new(key, message, hashlib.md5) # Don't use!
3. Protect the Key
# Store keys securely
# - Use environment variables
# - Use key management service (AWS KMS, etc.)
# - Never hardcode in source code
# - Never commit to version control
import os
key = os.environ.get('HMAC_KEY').encode()
# Rotate keys periodically
# Support multiple active keys during rotation
4. Include All Relevant Data
# Sign complete context
data = {
    'timestamp': timestamp,
    'user_id': user_id,
    'action': action,
    'nonce': nonce
}
message = json.dumps(data, sort_keys=True).encode()
signature = hmac.new(key, message, hashlib.sha256).hexdigest()
Common Mistakes
1. Using == for Comparison
# WRONG - timing attack
if hmac1 == hmac2:
    pass

# RIGHT - constant time
if hmac.compare_digest(hmac1, hmac2):
    pass
2. Not Including Timestamp
# WRONG - vulnerable to replay
signature = hmac.new(key, message, hashlib.sha256).hexdigest()
# RIGHT - include timestamp
data = f"{timestamp}:{message}"
signature = hmac.new(key, data.encode(), hashlib.sha256).hexdigest()
3. Wrong Key Derivation
# WRONG - weak key
key = b"password"
# RIGHT - derive from password
from hashlib import pbkdf2_hmac
key = pbkdf2_hmac('sha256', b'password', b'salt', 100000)
ELI10
HMAC is like a secret handshake for messages:
Imagine you and your best friend have a secret code:
- You write a message: "Meet at the treehouse at 3pm"
- You add your secret code and mix it all together in a special way
- You get a "stamp": a7f9e4b2...
- You send: message + stamp
When your friend receives it:
- They take the message
- They add the SAME secret code and mix it the SAME way
- They get their own stamp
- If their stamp matches yours, they know:
- The message really came from you (only you know the code!)
- Nobody changed the message (the stamp would be different!)
Why not just put the secret code in the message?
- Anyone could copy your code!
Why not just hash the message?
- Anyone could make their own hash!
HMAC is special because:
- You need the secret code to make the stamp
- Even a tiny change makes a completely different stamp
- Nobody can make the right stamp without knowing your secret code!
Real-world example: When you log into a website, your browser and the server use HMAC to:
- Make sure messages aren't tampered with
- Prove who sent the message
- Keep your session secure!
Further Resources
- RFC 2104 - HMAC Specification
- HMAC Security Analysis
- JWT Specification (RFC 7519)
- API Authentication Best Practices
- Timing Attack Prevention
- HKDF Specification (RFC 5869)
OAuth 2.0
OAuth 2.0 is an industry-standard authorization framework that enables applications to obtain limited access to user accounts on an HTTP service. It works by delegating user authentication to the service that hosts the user account and authorizing third-party applications to access the user account.
Table of Contents
- Introduction
- OAuth 2.0 Roles
- Grant Types
- Authorization Code Flow
- Client Credentials Flow
- Implementing OAuth 2.0
- OAuth 2.0 Providers
- Security Best Practices
- Common Vulnerabilities
Introduction
What is OAuth 2.0? OAuth 2.0 is an authorization framework, not an authentication protocol. It allows users to grant limited access to their resources on one site to another site, without sharing their credentials.
Key Benefits:
- Users don't share passwords with third-party apps
- Fine-grained access control (scopes)
- Time-limited access through tokens
- Revocable access
- Industry standard with wide support
Use Cases:
- Social login (Sign in with Google, Facebook, etc.)
- API access delegation
- Third-party application integration
- Microservices authentication
- Mobile app authentication
OAuth 2.0 Roles
1. Resource Owner
The user who owns the data and can grant access to it.
Example: John who has a Google account with photos
2. Client
The application requesting access to resources.
Example: A photo printing service that wants access to John's photos
3. Authorization Server
Server that authenticates the resource owner and issues access tokens.
Example: Google's OAuth 2.0 authorization server
4. Resource Server
Server hosting the protected resources.
Example: Google Photos API server
Grant Types
1. Authorization Code Flow
Best for: Server-side web applications
Flow:
1. Client redirects user to authorization server
2. User authenticates and grants permission
3. Authorization server redirects back with authorization code
4. Client exchanges code for access token
5. Client uses access token to access resources
Benefits:
- Most secure flow
- Refresh tokens supported
- Client secret never exposed to browser
2. Implicit Flow (Deprecated)
Status: Not recommended for new applications
Flow:
1. Client redirects user to authorization server
2. User authenticates and grants permission
3. Authorization server redirects with access token in URL fragment
Issues:
- Token exposed in browser history
- No refresh token
- Less secure
3. Client Credentials Flow
Best for: Server-to-server communication
Flow:
1. Client authenticates with client_id and client_secret
2. Authorization server returns access token
3. Client uses access token for API calls
Use cases:
- Microservices communication
- Batch jobs
- CLI tools
4. Resource Owner Password Credentials (Not Recommended)
Flow:
1. User provides username and password to client
2. Client exchanges credentials for access token
Issues:
- User shares credentials with client
- Defeats OAuth purpose
- Only for legacy systems
5. PKCE (Proof Key for Code Exchange)
Best for: Mobile and SPA applications
Enhancement to Authorization Code Flow:
1. Client generates code_verifier (random string)
2. Client creates code_challenge = hash(code_verifier)
3. Authorization request includes code_challenge
4. Token request includes code_verifier
5. Server verifies code_challenge matches code_verifier
Benefits:
- Protects against authorization code interception
- No client secret needed
- Secure for public clients
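The verifier/challenge derivation in the steps above can be sketched with the Python standard library (Python here rather than the Node used in later examples; the function names are illustrative, and the S256 method is SHA-256 of the verifier, base64url-encoded with padding stripped):

```python
import base64
import hashlib
import secrets

def generate_code_verifier():
    # 32 random bytes -> 43-char base64url string (RFC 7636 allows 43-128 chars)
    return base64.urlsafe_b64encode(secrets.token_bytes(32)).rstrip(b'=').decode()

def generate_code_challenge(verifier):
    # S256: BASE64URL(SHA256(verifier)), '=' padding stripped
    digest = hashlib.sha256(verifier.encode()).digest()
    return base64.urlsafe_b64encode(digest).rstrip(b'=').decode()

verifier = generate_code_verifier()
challenge = generate_code_challenge(verifier)
print(f"code_verifier:  {verifier}")
print(f"code_challenge: {challenge}")
```

The client keeps `code_verifier` secret and sends only `code_challenge` in the authorization request; at the token endpoint the server repeats the same hash over the submitted verifier and compares.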
Authorization Code Flow
Step-by-Step Implementation
Step 1: Authorization Request
GET /authorize?
response_type=code&
client_id=YOUR_CLIENT_ID&
redirect_uri=https://yourapp.com/callback&
scope=read:user read:email&
state=random_string
HTTP/1.1
Host: authorization-server.com
Parameters:
- response_type: Set to "code"
- client_id: Your application's client ID
- redirect_uri: Where to redirect after authorization
- scope: Requested permissions
- state: Random string to prevent CSRF
Step 2: User Authorization
User sees consent screen and approves/denies access.
Step 3: Authorization Response
HTTP/1.1 302 Found
Location: https://yourapp.com/callback?
code=AUTH_CODE&
state=random_string
Step 4: Token Request
POST /token HTTP/1.1
Host: authorization-server.com
Content-Type: application/x-www-form-urlencoded
grant_type=authorization_code&
code=AUTH_CODE&
redirect_uri=https://yourapp.com/callback&
client_id=YOUR_CLIENT_ID&
client_secret=YOUR_CLIENT_SECRET
Step 5: Token Response
{
"access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
"token_type": "Bearer",
"expires_in": 3600,
"refresh_token": "refresh_token_here",
"scope": "read:user read:email"
}
Step 6: Using Access Token
GET /api/user HTTP/1.1
Host: api.example.com
Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...
Node.js Implementation (Express)
const express = require('express');
const axios = require('axios');
const crypto = require('crypto');

const app = express();
// Assumes session middleware (e.g. express-session) is configured for req.session

const CLIENT_ID = process.env.CLIENT_ID;
const CLIENT_SECRET = process.env.CLIENT_SECRET;
const REDIRECT_URI = 'http://localhost:3000/callback';
const AUTHORIZATION_URL = 'https://authorization-server.com/authorize';
const TOKEN_URL = 'https://authorization-server.com/token';

// Step 1: Initiate authorization
app.get('/login', (req, res) => {
  const state = crypto.randomBytes(16).toString('hex');
  req.session.state = state;

  const authUrl = new URL(AUTHORIZATION_URL);
  authUrl.searchParams.append('response_type', 'code');
  authUrl.searchParams.append('client_id', CLIENT_ID);
  authUrl.searchParams.append('redirect_uri', REDIRECT_URI);
  authUrl.searchParams.append('scope', 'read:user read:email');
  authUrl.searchParams.append('state', state);

  res.redirect(authUrl.toString());
});

// Step 2: Handle callback
app.get('/callback', async (req, res) => {
  const { code, state } = req.query;

  // Verify state
  if (state !== req.session.state) {
    return res.status(400).send('Invalid state');
  }

  try {
    // Exchange code for token (token endpoints expect a form-encoded body)
    const tokenResponse = await axios.post(TOKEN_URL, new URLSearchParams({
      grant_type: 'authorization_code',
      code,
      redirect_uri: REDIRECT_URI,
      client_id: CLIENT_ID,
      client_secret: CLIENT_SECRET,
    }));

    const { access_token, refresh_token } = tokenResponse.data;

    // Store tokens securely
    req.session.access_token = access_token;
    req.session.refresh_token = refresh_token;

    res.redirect('/dashboard');
  } catch (error) {
    res.status(500).send('Authentication failed');
  }
});

// Step 3: Use access token
app.get('/api/user', async (req, res) => {
  const { access_token } = req.session;
  if (!access_token) {
    return res.status(401).send('Not authenticated');
  }

  try {
    const userResponse = await axios.get('https://api.example.com/user', {
      headers: {
        Authorization: `Bearer ${access_token}`,
      },
    });
    res.json(userResponse.data);
  } catch (error) {
    if (error.response?.status === 401) {
      // Token expired, refresh it
      return res.redirect('/refresh');
    }
    res.status(500).send('Failed to fetch user');
  }
});

// Refresh token
app.get('/refresh', async (req, res) => {
  const { refresh_token } = req.session;
  try {
    const tokenResponse = await axios.post(TOKEN_URL, new URLSearchParams({
      grant_type: 'refresh_token',
      refresh_token,
      client_id: CLIENT_ID,
      client_secret: CLIENT_SECRET,
    }));
    req.session.access_token = tokenResponse.data.access_token;
    res.redirect('/dashboard');
  } catch (error) {
    res.redirect('/login');
  }
});
Client Credentials Flow
Implementation Example
const axios = require('axios');

async function getAccessToken() {
  // Token endpoints expect a form-encoded body; URLSearchParams sets the
  // Content-Type header automatically
  const response = await axios.post(
    'https://authorization-server.com/token',
    new URLSearchParams({
      grant_type: 'client_credentials',
      client_id: process.env.CLIENT_ID,
      client_secret: process.env.CLIENT_SECRET,
      scope: 'api:read api:write',
    })
  );
  return response.data.access_token;
}

async function callAPI() {
  const token = await getAccessToken();
  const apiResponse = await axios.get('https://api.example.com/data', {
    headers: {
      Authorization: `Bearer ${token}`,
    },
  });
  return apiResponse.data;
}

// Usage
callAPI()
  .then(data => console.log(data))
  .catch(error => console.error(error));
Implementing OAuth 2.0
Building an OAuth 2.0 Server
Using Node.js with oauth2-server:
npm install express oauth2-server
server.js:
const express = require('express');
const OAuth2Server = require('oauth2-server');
const Request = OAuth2Server.Request;
const Response = OAuth2Server.Response;

const app = express();
// `db` below stands in for your data-access layer (not shown)

// OAuth2 model
const model = {
  getClient: async (clientId, clientSecret) => {
    // Fetch client from database
    const client = await db.clients.findOne({ clientId });
    if (!client || (clientSecret && client.clientSecret !== clientSecret)) {
      return null;
    }
    return {
      id: client.id,
      grants: ['authorization_code', 'refresh_token'],
      redirectUris: client.redirectUris,
    };
  },

  saveToken: async (token, client, user) => {
    // Save token to database
    await db.tokens.create({
      accessToken: token.accessToken,
      accessTokenExpiresAt: token.accessTokenExpiresAt,
      refreshToken: token.refreshToken,
      refreshTokenExpiresAt: token.refreshTokenExpiresAt,
      client: client.id,
      user: user.id,
    });
    return token;
  },

  getAccessToken: async (accessToken) => {
    const token = await db.tokens.findOne({ accessToken });
    if (!token) return null;
    return {
      accessToken: token.accessToken,
      accessTokenExpiresAt: token.accessTokenExpiresAt,
      client: { id: token.client },
      user: { id: token.user },
    };
  },

  getAuthorizationCode: async (authorizationCode) => {
    const code = await db.authCodes.findOne({ code: authorizationCode });
    if (!code) return null;
    return {
      code: code.code,
      expiresAt: code.expiresAt,
      redirectUri: code.redirectUri,
      client: { id: code.client },
      user: { id: code.user },
    };
  },

  saveAuthorizationCode: async (code, client, user) => {
    await db.authCodes.create({
      code: code.authorizationCode,
      expiresAt: code.expiresAt,
      redirectUri: code.redirectUri,
      client: client.id,
      user: user.id,
    });
    return code;
  },

  revokeAuthorizationCode: async (code) => {
    await db.authCodes.delete({ code: code.code });
    return true;
  },

  verifyScope: async (token, scope) => {
    if (!token.scope) return false;
    const requestedScopes = scope.split(' ');
    const authorizedScopes = token.scope.split(' ');
    return requestedScopes.every(s => authorizedScopes.includes(s));
  },
};

const oauth = new OAuth2Server({
  model: model,
  accessTokenLifetime: 3600,
  allowBearerTokensInQueryString: true,
});

// Authorization endpoint
app.get('/authorize', async (req, res) => {
  const request = new Request(req);
  const response = new Response(res);
  try {
    // Authenticate user (implement your own logic)
    const user = await authenticateUser(req);
    if (!user) {
      return res.redirect('/login');
    }
    const code = await oauth.authorize(request, response, {
      authenticateHandler: {
        handle: () => user,
      },
    });
    res.redirect(`${code.redirectUri}?code=${code.authorizationCode}&state=${req.query.state}`);
  } catch (error) {
    res.status(error.code || 500).json(error);
  }
});

// Token endpoint
app.post('/token', async (req, res) => {
  const request = new Request(req);
  const response = new Response(res);
  try {
    const token = await oauth.token(request, response);
    res.json(token);
  } catch (error) {
    res.status(error.code || 500).json(error);
  }
});

// Protected resource
app.get('/api/resource', async (req, res) => {
  const request = new Request(req);
  const response = new Response(res);
  try {
    const token = await oauth.authenticate(request, response);
    res.json({ message: 'Protected resource', user: token.user });
  } catch (error) {
    res.status(error.code || 401).json({ error: 'Unauthorized' });
  }
});
OAuth 2.0 Providers
Google OAuth 2.0
const passport = require('passport');
const GoogleStrategy = require('passport-google-oauth20').Strategy;

passport.use(new GoogleStrategy({
    clientID: process.env.GOOGLE_CLIENT_ID,
    clientSecret: process.env.GOOGLE_CLIENT_SECRET,
    callbackURL: "http://localhost:3000/auth/google/callback"
  },
  function(accessToken, refreshToken, profile, cb) {
    // Find or create user in your database
    User.findOrCreate({ googleId: profile.id }, function (err, user) {
      return cb(err, user);
    });
  }
));

app.get('/auth/google',
  passport.authenticate('google', { scope: ['profile', 'email'] })
);

app.get('/auth/google/callback',
  passport.authenticate('google', { failureRedirect: '/login' }),
  function(req, res) {
    res.redirect('/dashboard');
  }
);
GitHub OAuth 2.0
const GitHubStrategy = require('passport-github2').Strategy;

passport.use(new GitHubStrategy({
    clientID: process.env.GITHUB_CLIENT_ID,
    clientSecret: process.env.GITHUB_CLIENT_SECRET,
    callbackURL: "http://localhost:3000/auth/github/callback"
  },
  function(accessToken, refreshToken, profile, done) {
    User.findOrCreate({ githubId: profile.id }, function (err, user) {
      return done(err, user);
    });
  }
));

app.get('/auth/github',
  passport.authenticate('github', { scope: [ 'user:email' ] })
);

app.get('/auth/github/callback',
  passport.authenticate('github', { failureRedirect: '/login' }),
  function(req, res) {
    res.redirect('/dashboard');
  }
);
Custom OAuth 2.0 Client
class OAuth2Client {
  constructor(config) {
    this.clientId = config.clientId;
    this.clientSecret = config.clientSecret;
    this.redirectUri = config.redirectUri;
    this.authorizationUrl = config.authorizationUrl;
    this.tokenUrl = config.tokenUrl;
  }

  getAuthorizationUrl(state, scope) {
    const params = new URLSearchParams({
      response_type: 'code',
      client_id: this.clientId,
      redirect_uri: this.redirectUri,
      scope: scope.join(' '),
      state,
    });
    return `${this.authorizationUrl}?${params.toString()}`;
  }

  async exchangeCodeForToken(code) {
    const response = await fetch(this.tokenUrl, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/x-www-form-urlencoded',
      },
      body: new URLSearchParams({
        grant_type: 'authorization_code',
        code,
        redirect_uri: this.redirectUri,
        client_id: this.clientId,
        client_secret: this.clientSecret,
      }),
    });
    if (!response.ok) {
      throw new Error('Token exchange failed');
    }
    return await response.json();
  }

  async refreshToken(refreshToken) {
    const response = await fetch(this.tokenUrl, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/x-www-form-urlencoded',
      },
      body: new URLSearchParams({
        grant_type: 'refresh_token',
        refresh_token: refreshToken,
        client_id: this.clientId,
        client_secret: this.clientSecret,
      }),
    });
    if (!response.ok) {
      throw new Error('Token refresh failed');
    }
    return await response.json();
  }

  async getUserInfo(accessToken) {
    const response = await fetch('https://api.example.com/user', {
      headers: {
        Authorization: `Bearer ${accessToken}`,
      },
    });
    if (!response.ok) {
      throw new Error('Failed to fetch user info');
    }
    return await response.json();
  }
}

// Usage
const client = new OAuth2Client({
  clientId: process.env.CLIENT_ID,
  clientSecret: process.env.CLIENT_SECRET,
  redirectUri: 'http://localhost:3000/callback',
  authorizationUrl: 'https://provider.com/authorize',
  tokenUrl: 'https://provider.com/token',
});

// Generate authorization URL
const authUrl = client.getAuthorizationUrl('random_state', ['read:user', 'read:email']);

// Exchange code for token
const tokens = await client.exchangeCodeForToken('authorization_code');

// Get user info
const user = await client.getUserInfo(tokens.access_token);
Security Best Practices
1. Always Use HTTPS
All OAuth 2.0 endpoints must use HTTPS to prevent token interception.
2. Validate Redirect URIs
function validateRedirectUri(redirectUri, registeredUris) {
  return registeredUris.includes(redirectUri);
}
3. Use State Parameter
const state = crypto.randomBytes(32).toString('hex');
req.session.oauthState = state;

// Verify on callback
if (req.query.state !== req.session.oauthState) {
  throw new Error('Invalid state parameter');
}
4. Implement PKCE
// Generate code verifier
const codeVerifier = crypto.randomBytes(32).toString('base64url');

// Generate code challenge
const codeChallenge = crypto
  .createHash('sha256')
  .update(codeVerifier)
  .digest('base64url');

// Store code verifier
req.session.codeVerifier = codeVerifier;

// Include in authorization request
const authUrl = `${AUTHORIZATION_URL}?code_challenge=${codeChallenge}&code_challenge_method=S256`;
5. Secure Token Storage
// Never store tokens in localStorage or sessionStorage
// Use secure, httpOnly cookies
res.cookie('access_token', token, {
  httpOnly: true,
  secure: true,
  sameSite: 'strict',
  maxAge: 3600000,
});
6. Token Expiration
// Always set token expiration
{
"access_token": "...",
"expires_in": 3600,
"refresh_token": "..."
}
// Check expiration before use
if (Date.now() >= tokenExpiresAt) {
// Refresh token
await refreshAccessToken();
}
7. Scope Limitation
// Request only necessary scopes
const scopes = ['read:user', 'read:email']; // Don't request write access if not needed

// Validate scopes on the server
function validateScopes(requestedScopes, userGrantedScopes) {
  return requestedScopes.every(scope => userGrantedScopes.includes(scope));
}
Common Vulnerabilities
1. Authorization Code Interception
Vulnerability: Attacker intercepts authorization code
Mitigation: Use PKCE (Proof Key for Code Exchange)
// Generate PKCE parameters
const codeVerifier = generateCodeVerifier();
const codeChallenge = generateCodeChallenge(codeVerifier);
// Store code_verifier securely
// Include code_challenge in authorization request
2. Redirect URI Manipulation
Vulnerability: Attacker changes redirect_uri to malicious site
Mitigation:
// Strictly validate redirect URIs
const ALLOWED_REDIRECT_URIS = [
  'https://app.example.com/callback',
  'https://app.example.com/oauth/callback'
];

function validateRedirectUri(uri) {
  return ALLOWED_REDIRECT_URIS.includes(uri);
}
3. CSRF Attacks
Vulnerability: Attacker tricks user into authorizing their account
Mitigation:
// Always use state parameter
const state = crypto.randomBytes(16).toString('hex');
req.session.state = state;

// Verify state on callback
if (req.query.state !== req.session.state) {
  throw new Error('CSRF detected');
}
4. Token Leakage
Vulnerability: Tokens exposed in URLs, logs, or browser history
Mitigation:
// Never include tokens in URLs
// ❌ Bad
window.location.href = `/api/data?token=${accessToken}`;

// ✅ Good
fetch('/api/data', {
  headers: {
    'Authorization': `Bearer ${accessToken}`
  }
});
5. Insufficient Token Validation
Vulnerability: Server doesn't properly validate tokens
Mitigation:
async function validateToken(token) {
  // 1. Verify token signature
  // 2. Check expiration
  // 3. Verify issuer
  // 4. Verify audience
  // 5. Check revocation status
  if (token.exp < Date.now() / 1000) {
    throw new Error('Token expired');
  }
  if (token.iss !== EXPECTED_ISSUER) {
    throw new Error('Invalid issuer');
  }
  // Check if token is revoked
  const isRevoked = await checkRevocationList(token.jti);
  if (isRevoked) {
    throw new Error('Token revoked');
  }
  return true;
}
Resources
Libraries:
- Passport.js (Node.js)
- OAuth2 Server (Node.js)
- Authlib (Python)
- Spring Security OAuth (Java)
JWT (JSON Web Tokens)
JSON Web Token (JWT) is an open standard (RFC 7519) that defines a compact and self-contained way for securely transmitting information between parties as a JSON object. This information can be verified and trusted because it is digitally signed.
Table of Contents
- Introduction
- JWT Structure
- How JWT Works
- Creating and Verifying JWTs
- JWT Authentication Flow
- Refresh Tokens
- Security Best Practices
- Common Vulnerabilities
Introduction
What is JWT?
A JWT is a string of three Base64-URL encoded parts separated by dots (.), representing:
- Header
- Payload
- Signature
Example JWT:
eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c
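To see the structure, the example token above can be split on the dots and base64url-decoded with nothing but the Python standard library (decoding only, no signature check):

```python
import base64
import json

token = ("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9."
         "eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ."
         "SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c")

def b64url_decode(part):
    # Restore the stripped '=' padding before decoding
    return base64.urlsafe_b64decode(part + '=' * (-len(part) % 4))

header_b64, payload_b64, signature_b64 = token.split('.')
header = json.loads(b64url_decode(header_b64))
payload = json.loads(b64url_decode(payload_b64))
print(header)   # {'alg': 'HS256', 'typ': 'JWT'}
print(payload)  # {'sub': '1234567890', 'name': 'John Doe', 'iat': 1516239022}
```

Note that anyone can decode a JWT like this; only the signature (the third part) proves it hasn't been altered.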
Use Cases:
- Authentication
- Information Exchange
- Authorization
- Single Sign-On (SSO)
- Stateless API authentication
Benefits:
- Compact size
- Self-contained (all info in the token)
- Stateless (no server-side session storage)
- Cross-domain/CORS friendly
- Mobile-friendly
JWT Structure
Header
Contains the type of token and the signing algorithm.
{
"alg": "HS256",
"typ": "JWT"
}
Common Algorithms:
- HS256 (HMAC with SHA-256) - Symmetric
- RS256 (RSA with SHA-256) - Asymmetric
- ES256 (ECDSA with SHA-256) - Asymmetric
Payload
Contains the claims (statements about an entity and additional data).
{
"sub": "1234567890",
"name": "John Doe",
"email": "john@example.com",
"iat": 1516239022,
"exp": 1516242622,
"roles": ["user", "admin"]
}
Registered Claims:
- iss (issuer)
- sub (subject)
- aud (audience)
- exp (expiration time)
- nbf (not before)
- iat (issued at)
- jti (JWT ID)
Custom Claims: Any additional data you want to include.
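Note that the payload is only Base64-URL encoded, not encrypted: anyone holding a token can read every claim, so secrets never belong in the payload. A minimal Node.js sketch decoding the example token above (inspection only; the claims are not trustworthy until the signature is verified):

```javascript
// Decode a JWT payload WITHOUT verifying the signature.
// Useful for inspection only - never trust unverified claims.
function decodePayload(token) {
  const payloadB64 = token.split('.')[1];
  return JSON.parse(Buffer.from(payloadB64, 'base64url').toString('utf8'));
}

const example =
  'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.' +
  'eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.' +
  'SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c';

console.log(decodePayload(example)); // { sub: '1234567890', name: 'John Doe', iat: 1516239022 }
```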
Signature
Created by taking:
HMACSHA256(
base64UrlEncode(header) + "." +
base64UrlEncode(payload),
secret
)
How JWT Works
Authentication Flow
1. User logs in with credentials
↓
2. Server validates credentials
↓
3. Server creates JWT with user info
↓
4. Server sends JWT to client
↓
5. Client stores JWT (localStorage/cookie)
↓
6. Client sends JWT with each request
↓
7. Server verifies JWT signature
↓
8. Server grants/denies access
Creating and Verifying JWTs
Node.js Implementation
Installation:
npm install jsonwebtoken
Creating a JWT:
const jwt = require('jsonwebtoken');
const SECRET_KEY = process.env.JWT_SECRET;
function generateToken(user) {
const payload = {
sub: user.id,
email: user.email,
name: user.name,
roles: user.roles,
};
const options = {
expiresIn: '1h',
issuer: 'your-app-name',
audience: 'your-app-users',
};
return jwt.sign(payload, SECRET_KEY, options);
}
// Usage
const token = generateToken({
id: 123,
email: 'john@example.com',
name: 'John Doe',
roles: ['user'],
});
console.log(token);
Verifying a JWT:
function verifyToken(token) {
try {
const decoded = jwt.verify(token, SECRET_KEY, {
issuer: 'your-app-name',
audience: 'your-app-users',
});
return decoded;
} catch (error) {
if (error.name === 'TokenExpiredError') {
throw new Error('Token expired');
}
if (error.name === 'JsonWebTokenError') {
throw new Error('Invalid token');
}
throw error;
}
}
// Usage
try {
const decoded = verifyToken(token);
console.log('User:', decoded);
} catch (error) {
console.error('Verification failed:', error.message);
}
Express Middleware
const jwt = require('jsonwebtoken');
function authenticateToken(req, res, next) {
const authHeader = req.headers['authorization'];
const token = authHeader && authHeader.split(' ')[1]; // Bearer TOKEN
if (!token) {
return res.status(401).json({ error: 'Access token required' });
}
try {
const user = jwt.verify(token, process.env.JWT_SECRET);
req.user = user;
next();
} catch (error) {
return res.status(403).json({ error: 'Invalid or expired token' });
}
}
// Protected route
app.get('/api/protected', authenticateToken, (req, res) => {
res.json({
message: 'Protected data',
user: req.user,
});
});
Python Implementation (PyJWT)
import jwt
import datetime
from functools import wraps
from flask import request, jsonify
SECRET_KEY = "your-secret-key"
def generate_token(user_id, email):
payload = {
'sub': user_id,
'email': email,
'exp': datetime.datetime.utcnow() + datetime.timedelta(hours=1),
'iat': datetime.datetime.utcnow(),
'iss': 'your-app-name'
}
return jwt.encode(payload, SECRET_KEY, algorithm='HS256')
def verify_token(token):
try:
decoded = jwt.decode(
token,
SECRET_KEY,
algorithms=['HS256'],
issuer='your-app-name'
)
return decoded
except jwt.ExpiredSignatureError:
raise Exception('Token expired')
except jwt.InvalidTokenError:
raise Exception('Invalid token')
# Decorator for protected routes
def token_required(f):
@wraps(f)
def decorated(*args, **kwargs):
token = request.headers.get('Authorization')
if not token:
return jsonify({'error': 'Token missing'}), 401
try:
token = token.split(' ')[1] # Remove 'Bearer '
decoded = verify_token(token)
request.user = decoded
except Exception as e:
return jsonify({'error': str(e)}), 403
return f(*args, **kwargs)
return decorated
# Protected route
@app.route('/api/protected')
@token_required
def protected():
return jsonify({
'message': 'Protected data',
'user': request.user
})
JWT Authentication Flow
Complete Implementation
auth.js:
const express = require('express');
const jwt = require('jsonwebtoken');
const bcrypt = require('bcrypt');
const router = express.Router();
const SECRET_KEY = process.env.JWT_SECRET;
const REFRESH_SECRET = process.env.REFRESH_SECRET;
// Login
router.post('/login', async (req, res) => {
const { email, password } = req.body;
// Find user in database
const user = await User.findOne({ email });
if (!user) {
return res.status(401).json({ error: 'Invalid credentials' });
}
// Verify password
const isValidPassword = await bcrypt.compare(password, user.password);
if (!isValidPassword) {
return res.status(401).json({ error: 'Invalid credentials' });
}
// Generate tokens
const accessToken = jwt.sign(
{
sub: user.id,
email: user.email,
roles: user.roles,
},
SECRET_KEY,
{ expiresIn: '15m' }
);
const refreshToken = jwt.sign(
{ sub: user.id },
REFRESH_SECRET,
{ expiresIn: '7d' }
);
// Store refresh token in database
await RefreshToken.create({
token: refreshToken,
userId: user.id,
expiresAt: new Date(Date.now() + 7 * 24 * 60 * 60 * 1000),
});
// Send tokens
res.json({
accessToken,
refreshToken,
user: {
id: user.id,
email: user.email,
name: user.name,
},
});
});
// Refresh token
router.post('/refresh', async (req, res) => {
const { refreshToken } = req.body;
if (!refreshToken) {
return res.status(401).json({ error: 'Refresh token required' });
}
try {
// Verify refresh token
const decoded = jwt.verify(refreshToken, REFRESH_SECRET);
// Check if refresh token exists in database
const storedToken = await RefreshToken.findOne({
token: refreshToken,
userId: decoded.sub,
});
if (!storedToken) {
return res.status(403).json({ error: 'Invalid refresh token' });
}
// Get user
const user = await User.findById(decoded.sub);
// Generate new access token
const accessToken = jwt.sign(
{
sub: user.id,
email: user.email,
roles: user.roles,
},
SECRET_KEY,
{ expiresIn: '15m' }
);
res.json({ accessToken });
} catch (error) {
return res.status(403).json({ error: 'Invalid refresh token' });
}
});
// Logout
router.post('/logout', authenticateToken, async (req, res) => {
const { refreshToken } = req.body;
// Remove refresh token from database
await RefreshToken.deleteOne({
token: refreshToken,
userId: req.user.sub,
});
res.json({ message: 'Logged out successfully' });
});
module.exports = router;
Client-Side Implementation
class AuthService {
constructor() {
this.accessToken = null;
this.refreshToken = localStorage.getItem('refreshToken');
}
async login(email, password) {
const response = await fetch('/api/auth/login', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ email, password }),
});
if (!response.ok) {
throw new Error('Login failed');
}
const data = await response.json();
this.accessToken = data.accessToken;
this.refreshToken = data.refreshToken;
localStorage.setItem('refreshToken', data.refreshToken);
return data.user;
}
async refreshAccessToken() {
if (!this.refreshToken) {
throw new Error('No refresh token');
}
const response = await fetch('/api/auth/refresh', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ refreshToken: this.refreshToken }),
});
if (!response.ok) {
this.logout();
throw new Error('Token refresh failed');
}
const data = await response.json();
this.accessToken = data.accessToken;
return this.accessToken;
}
async makeAuthenticatedRequest(url, options = {}) {
if (!this.accessToken) {
await this.refreshAccessToken();
}
const response = await fetch(url, {
...options,
headers: {
...options.headers,
'Authorization': `Bearer ${this.accessToken}`,
},
});
// If token expired, refresh and retry
if (response.status === 401) {
await this.refreshAccessToken();
return fetch(url, {
...options,
headers: {
...options.headers,
'Authorization': `Bearer ${this.accessToken}`,
},
});
}
return response;
}
logout() {
this.accessToken = null;
this.refreshToken = null;
localStorage.removeItem('refreshToken');
}
}
// Usage
const auth = new AuthService();
// Login
await auth.login('user@example.com', 'password');
// Make authenticated request
const response = await auth.makeAuthenticatedRequest('/api/user/profile');
const profile = await response.json();
// Logout
auth.logout();
Refresh Tokens
Why Use Refresh Tokens?
- Short-lived access tokens reduce the window of opportunity for token theft
- Long-lived refresh tokens improve user experience (don't have to login frequently)
- Revocable - Can invalidate refresh tokens without affecting other sessions
Implementation Strategy
// Token lifetimes
const ACCESS_TOKEN_LIFETIME = '15m';
const REFRESH_TOKEN_LIFETIME = '7d';
// Store refresh tokens in database
const refreshTokenSchema = new mongoose.Schema({
token: { type: String, required: true, unique: true },
userId: { type: mongoose.Schema.Types.ObjectId, ref: 'User', required: true },
expiresAt: { type: Date, required: true },
createdAt: { type: Date, default: Date.now },
revokedAt: { type: Date },
replacedByToken: { type: String },
});
// Automatic cleanup of expired tokens
refreshTokenSchema.index({ expiresAt: 1 }, { expireAfterSeconds: 0 });
// Token rotation
async function rotateRefreshToken(oldRefreshToken) {
// Verify old token
const decoded = jwt.verify(oldRefreshToken, REFRESH_SECRET);
// Find old token in database
const oldToken = await RefreshToken.findOne({
token: oldRefreshToken,
userId: decoded.sub,
});
if (!oldToken || oldToken.revokedAt) {
throw new Error('Invalid refresh token');
}
// Create new refresh token
const newRefreshToken = jwt.sign(
{ sub: decoded.sub },
REFRESH_SECRET,
{ expiresIn: REFRESH_TOKEN_LIFETIME }
);
// Mark old token as revoked
oldToken.revokedAt = new Date();
oldToken.replacedByToken = newRefreshToken;
await oldToken.save();
// Store new token
await RefreshToken.create({
token: newRefreshToken,
userId: decoded.sub,
expiresAt: new Date(Date.now() + 7 * 24 * 60 * 60 * 1000),
});
return newRefreshToken;
}
Security Best Practices
1. Use Strong Secrets
// Generate a strong secret
const crypto = require('crypto');
const secret = crypto.randomBytes(64).toString('hex');
// Use environment variables
const SECRET_KEY = process.env.JWT_SECRET;
if (!SECRET_KEY || SECRET_KEY.length < 32) {
throw new Error('JWT_SECRET must be at least 32 characters');
}
2. Short Expiration Times
// Short-lived access tokens
const accessToken = jwt.sign(payload, SECRET_KEY, {
expiresIn: '15m', // 15 minutes
});
// Long-lived refresh tokens
const refreshToken = jwt.sign(payload, REFRESH_SECRET, {
expiresIn: '7d', // 7 days
});
3. Secure Token Storage
// ❌ Bad: localStorage (vulnerable to XSS)
localStorage.setItem('token', accessToken);
// ✅ Good: httpOnly cookie (protected from XSS)
res.cookie('access_token', accessToken, {
httpOnly: true,
secure: true,
sameSite: 'strict',
maxAge: 15 * 60 * 1000, // 15 minutes
});
// ✅ Good: Memory (for SPAs)
class TokenStore {
constructor() {
this.token = null;
}
setToken(token) {
this.token = token;
}
getToken() {
return this.token;
}
clearToken() {
this.token = null;
}
}
4. Validate All Claims
function validateToken(token) {
const decoded = jwt.verify(token, SECRET_KEY, {
issuer: 'your-app',
audience: 'your-users',
});
// Check expiration
if (decoded.exp < Date.now() / 1000) {
throw new Error('Token expired');
}
// Check not before
if (decoded.nbf && decoded.nbf > Date.now() / 1000) {
throw new Error('Token not yet valid');
}
// Validate custom claims
if (!decoded.roles || !Array.isArray(decoded.roles)) {
throw new Error('Invalid token structure');
}
return decoded;
}
5. Use Asymmetric Algorithms for Distributed Systems
const fs = require('fs');
// Generate RSA key pair
const { generateKeyPairSync } = require('crypto');
const { privateKey, publicKey } = generateKeyPairSync('rsa', {
modulusLength: 2048,
});
// Sign with private key
const token = jwt.sign(payload, privateKey, {
algorithm: 'RS256',
expiresIn: '1h',
});
// Verify with public key (can be shared with other services)
const decoded = jwt.verify(token, publicKey, {
algorithms: ['RS256'],
});
6. Implement Token Blacklist for Logout
// Store revoked jti values in Redis until the token would have expired anyway
// (requires tokens to be signed with a unique jti claim)
async function logout(token) {
const decoded = jwt.decode(token);
// Add to blacklist with expiration
await redis.setex(
`blacklist:${decoded.jti}`,
decoded.exp - Math.floor(Date.now() / 1000),
'true'
);
}
async function isTokenBlacklisted(token) {
const decoded = jwt.decode(token);
const isBlacklisted = await redis.exists(`blacklist:${decoded.jti}`);
return isBlacklisted === 1;
}
// Middleware
async function authenticateToken(req, res, next) {
const token = extractToken(req);
if (await isTokenBlacklisted(token)) {
return res.status(401).json({ error: 'Token has been revoked' });
}
// Verify token...
next();
}
Common Vulnerabilities
1. Algorithm Confusion Attack
Vulnerability: Attacker changes algorithm from RS256 to HS256 and uses public key as secret
Mitigation:
// Always specify allowed algorithms
jwt.verify(token, secret, {
algorithms: ['RS256'], // Only allow RS256
});
// Never use 'none' algorithm
jwt.sign(payload, secret, {
algorithm: 'HS256', // Specify algorithm explicitly
});
2. Weak Secret Keys
Vulnerability: Short or predictable secrets can be brute-forced
Mitigation:
// Use strong, random secrets (at least 256 bits)
const crypto = require('crypto');
const secret = crypto.randomBytes(32).toString('hex');
// Store in environment variables
const SECRET_KEY = process.env.JWT_SECRET;
// Validate secret strength
if (SECRET_KEY.length < 32) {
throw new Error('Secret key too short');
}
3. Token Leakage in URLs
Vulnerability: Tokens in URL parameters are logged and visible
Mitigation:
// ❌ Bad: Token in URL
fetch(`/api/data?token=${accessToken}`);
// ✅ Good: Token in header
fetch('/api/data', {
headers: {
'Authorization': `Bearer ${accessToken}`,
},
});
4. Missing Expiration
Vulnerability: Tokens without expiration never expire
Mitigation:
// Always set expiration
const token = jwt.sign(payload, secret, {
expiresIn: '15m',
});
// Verify expiration on the server
jwt.verify(token, secret, {
clockTolerance: 0, // No clock-skew allowance when checking expiry
});
5. XSS Attacks
Vulnerability: Tokens stored in localStorage can be stolen via XSS
Mitigation:
// Use httpOnly cookies
res.cookie('token', token, {
httpOnly: true,
secure: true,
sameSite: 'strict',
});
// Or store in memory (for SPAs)
// Never use localStorage or sessionStorage for sensitive tokens
6. Insufficient Token Validation
Vulnerability: Not validating all claims or checking token blacklist
Mitigation:
async function validateToken(token) {
// 1. Verify signature
const decoded = jwt.verify(token, SECRET_KEY);
// 2. Check blacklist
if (await isBlacklisted(decoded.jti)) {
throw new Error('Token revoked');
}
// 3. Validate issuer
if (decoded.iss !== EXPECTED_ISSUER) {
throw new Error('Invalid issuer');
}
// 4. Validate audience
if (decoded.aud !== EXPECTED_AUDIENCE) {
throw new Error('Invalid audience');
}
// 5. Additional business logic checks
const user = await User.findById(decoded.sub);
if (!user || !user.isActive) {
throw new Error('User not found or inactive');
}
return decoded;
}
Resources
Libraries:
- jsonwebtoken (Node.js)
- PyJWT (Python)
- jose (Node.js, modern)
- java-jwt (Java)
Wifi
Welcome to the Wifi section of the notes. Here, you will find comprehensive information about various aspects of Wifi technology, including its basics, standards, security, and more.
Table of Contents
Explore the topics to gain a deeper understanding of Wifi technology and its applications.
Wifi Basics
Aggregation
Aggregation in Wi-Fi refers to the process of combining multiple data frames into a single transmission unit. This technique is used to improve the efficiency and throughput of wireless networks by reducing the overhead associated with each individual frame transmission. There are two main types of aggregation in Wi-Fi:
-
A-MPDU (Aggregated MAC Protocol Data Unit):
- Combines multiple MAC frames into a single PHY (Physical Layer) frame.
- Reduces the inter-frame spacing and acknowledgment overhead.
- Improves throughput by allowing multiple frames to be sent in a single transmission burst.
-
A-MSDU (Aggregated MAC Service Data Unit):
- Combines multiple MSDUs (MAC Service Data Units) into a single MPDU (MAC Protocol Data Unit).
- Reduces the overhead by aggregating data at the MAC layer before it is passed to the PHY layer.
- Increases efficiency by reducing the number of headers and acknowledgments required.
Both A-MPDU and A-MSDU are supported in 802.11n and later standards, such as 802.11ac and 802.11ax. These aggregation techniques are particularly beneficial in high-throughput and high-density environments, where they help to maximize the use of available bandwidth and improve overall network performance.
Wifi Bands
2.4 GHz
- 802.11b
- 802.11g
- 802.11n
- 802.11ax
The 2.4 GHz band is one of the most commonly used frequency bands for Wi-Fi communication. It is known for its longer range and better penetration through obstacles such as walls and floors. However, it is also more susceptible to interference from other devices, such as microwaves, cordless phones, and Bluetooth devices, which operate in the same frequency range.
Channels in 2.4 GHz Band
The 2.4 GHz band is divided into multiple channels, each with a specific frequency range. The channels are spaced 5 MHz apart, but due to the width of the channels (22 MHz), there is significant overlap between adjacent channels. This can lead to interference if multiple networks are operating on overlapping channels. The commonly used channels in the 2.4 GHz band are:
- Channel 1: 2.412 GHz
- Channel 2: 2.417 GHz
- Channel 3: 2.422 GHz
- Channel 4: 2.427 GHz
- Channel 5: 2.432 GHz
- Channel 6: 2.437 GHz
- Channel 7: 2.442 GHz
- Channel 8: 2.447 GHz
- Channel 9: 2.452 GHz
- Channel 10: 2.457 GHz
- Channel 11: 2.462 GHz
In some regions, additional channels are available:
- Channel 12: 2.467 GHz
- Channel 13: 2.472 GHz
- Channel 14: 2.484 GHz (only available in Japan)
To minimize interference, it is recommended to use non-overlapping channels. In the 2.4 GHz band, the non-overlapping channels are typically channels 1, 6, and 11. By configuring Wi-Fi networks to operate on these channels, interference can be reduced, leading to improved performance and reliability.
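The channel-to-frequency mapping above follows a simple rule: center frequency (MHz) = 2407 + 5 × channel number, with channel 14 as a special case off the 5 MHz grid. A small sketch:

```javascript
// Center frequency (MHz) of a 2.4 GHz Wi-Fi channel.
function channel24ToMHz(ch) {
  if (ch === 14) return 2484;                 // Japan only; off the 5 MHz grid
  if (ch >= 1 && ch <= 13) return 2407 + 5 * ch;
  throw new RangeError(`invalid 2.4 GHz channel: ${ch}`);
}

console.log(channel24ToMHz(6)); // 2437 -> channel 6 centers at 2.437 GHz
```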
5 GHz
- 802.11a
- 802.11n
- 802.11ac
- 802.11ax
Channels in 5 GHz Band
The 5 GHz band offers a larger number of channels compared to the 2.4 GHz band, which helps to reduce interference and congestion. The channels in the 5 GHz band are spaced 20 MHz apart, and there are several non-overlapping channels available. This band is divided into several sub-bands, each with its own set of channels:
-
UNII-1 (5150-5250 MHz):
- Channel 36: 5.180 GHz
- Channel 40: 5.200 GHz
- Channel 44: 5.220 GHz
- Channel 48: 5.240 GHz
-
UNII-2 (5250-5350 MHz):
- Channel 52: 5.260 GHz
- Channel 56: 5.280 GHz
- Channel 60: 5.300 GHz
- Channel 64: 5.320 GHz
-
UNII-2 Extended (5470-5725 MHz):
- Channel 100: 5.500 GHz
- Channel 104: 5.520 GHz
- Channel 108: 5.540 GHz
- Channel 112: 5.560 GHz
- Channel 116: 5.580 GHz
- Channel 120: 5.600 GHz
- Channel 124: 5.620 GHz
- Channel 128: 5.640 GHz
- Channel 132: 5.660 GHz
- Channel 136: 5.680 GHz
- Channel 140: 5.700 GHz
- Channel 144: 5.720 GHz
-
UNII-3 (5725-5850 MHz):
- Channel 149: 5.745 GHz
- Channel 153: 5.765 GHz
- Channel 157: 5.785 GHz
- Channel 161: 5.805 GHz
- Channel 165: 5.825 GHz
The 5 GHz band is less crowded than the 2.4 GHz band and offers higher data rates and lower latency. However, it has a shorter range and less ability to penetrate obstacles such as walls and floors. The use of non-overlapping channels in the 5 GHz band helps to minimize interference and improve overall network performance. Additionally, Dynamic Frequency Selection (DFS) is used in some channels to avoid interference with radar systems.
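For the 5 GHz band, the channel numbers above map directly to center frequencies via center (MHz) = 5000 + 5 × channel number; the DFS requirement mentioned above applies to the UNII-2 and UNII-2 Extended channels. A sketch:

```javascript
// Center frequency (MHz) of a 5 GHz Wi-Fi channel (20 MHz channels).
function channel5ToMHz(ch) {
  return 5000 + 5 * ch;
}

// DFS (radar avoidance) applies to UNII-2 (52-64) and UNII-2 Extended (100-144).
function requiresDFS(ch) {
  return ch >= 52 && ch <= 144;
}

console.log(channel5ToMHz(36)); // 5180 -> channel 36 centers at 5.180 GHz
```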
6 GHz
- 802.11ax
- 802.11be
Channels in 6 GHz Band
The 6 GHz band is a new addition to the Wi-Fi spectrum, providing even more channels and bandwidth for wireless communication. This band is divided into several sub-bands, each with its own set of channels. The channels in the 6 GHz band are spaced 20 MHz apart, similar to the 5 GHz band, and there are numerous non-overlapping channels available. The 6 GHz band offers higher data rates, lower latency, and reduced interference compared to the 2.4 GHz and 5 GHz bands.
Channel center frequencies in this band follow center (MHz) = 5950 + 5 × channel number.
-
UNII-5 (5925-6425 MHz):
- Channel 1: 5.955 GHz
- Channel 5: 5.975 GHz
- Channel 9: 5.995 GHz
- Channel 13: 6.015 GHz
- Channel 17: 6.035 GHz
- Channel 21: 6.055 GHz
- Channel 25: 6.075 GHz
- Channel 29: 6.095 GHz
- Channel 33: 6.115 GHz
- Channel 37: 6.135 GHz
- Channel 41: 6.155 GHz
- Channel 45: 6.175 GHz
- Channel 49: 6.195 GHz
- Channel 53: 6.215 GHz
- Channel 57: 6.235 GHz
- Channel 61: 6.255 GHz
- Channel 65: 6.275 GHz
- Channel 69: 6.295 GHz
- Channel 73: 6.315 GHz
- Channel 77: 6.335 GHz
- Channel 81: 6.355 GHz
- Channel 85: 6.375 GHz
- Channel 89: 6.395 GHz
- Channel 93: 6.415 GHz
-
UNII-6 (6425-6525 MHz):
- Channel 97: 6.435 GHz
- Channel 101: 6.455 GHz
- Channel 105: 6.475 GHz
- Channel 109: 6.495 GHz
- Channel 113: 6.515 GHz
-
UNII-7 (6525-6875 MHz):
- Channel 117: 6.535 GHz
- Channel 121: 6.555 GHz
- Channel 125: 6.575 GHz
- Channel 129: 6.595 GHz
- Channel 133: 6.615 GHz
- Channel 137: 6.635 GHz
- Channel 141: 6.655 GHz
- Channel 145: 6.675 GHz
- Channel 149: 6.695 GHz
- Channel 153: 6.715 GHz
- Channel 157: 6.735 GHz
- Channel 161: 6.755 GHz
- Channel 165: 6.775 GHz
- Channel 169: 6.795 GHz
- Channel 173: 6.815 GHz
- Channel 177: 6.835 GHz
- Channel 181: 6.855 GHz
-
UNII-8 (6875-7125 MHz):
- Channel 185: 6.875 GHz
- Channel 189: 6.895 GHz
- Channel 193: 6.915 GHz
- Channel 197: 6.935 GHz
- Channel 201: 6.955 GHz
- Channel 205: 6.975 GHz
- Channel 209: 6.995 GHz
- Channel 213: 7.015 GHz
- Channel 217: 7.035 GHz
- Channel 221: 7.055 GHz
- Channel 225: 7.075 GHz
- Channel 229: 7.095 GHz
- Channel 233: 7.115 GHz
The 6 GHz band is expected to significantly enhance Wi-Fi performance, especially in dense environments, by providing more spectrum and reducing congestion. Devices that support the 6 GHz band can take advantage of these additional channels to achieve faster speeds and more reliable connections.
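The 6 GHz channelization defined in 802.11ax places 20 MHz channel centers on a regular grid starting at 5.950 GHz, so center (MHz) = 5950 + 5 × channel number. A sketch of that mapping:

```javascript
// Center frequency (MHz) of a 6 GHz Wi-Fi channel, per the 802.11ax
// channelization (channel starting frequency 5.950 GHz).
function channel6ToMHz(ch) {
  if (ch < 1 || ch > 233) throw new RangeError(`invalid 6 GHz channel: ${ch}`);
  return 5950 + 5 * ch;
}

console.log(channel6ToMHz(37)); // 6135 -> channel 37 centers at 6.135 GHz
```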
Wifi Channel Width
Wi-Fi channel width refers to the size of the frequency band that a Wi-Fi signal occupies. The channel width determines the data rate and the amount of data that can be transmitted over the network. Wider channels can carry more data, but they are also more susceptible to interference and congestion. The most common channel widths in Wi-Fi are 20 MHz, 40 MHz, 80 MHz, and 160 MHz.
20 MHz Channels
20 MHz is the standard channel width for Wi-Fi and is widely used in both 2.4 GHz and 5 GHz bands. It provides a good balance between range and throughput. A 20 MHz channel is less likely to experience interference from other devices and networks, making it a reliable choice for most applications.
40 MHz Channels
40 MHz channels are used to increase the data rate by bonding two adjacent 20 MHz channels. This effectively doubles the bandwidth, allowing for higher throughput. However, 40 MHz channels are more prone to interference, especially in the crowded 2.4 GHz band. In the 5 GHz band, 40 MHz channels are more practical due to the availability of more non-overlapping channels.
80 MHz Channels
80 MHz channels further increase the data rate by bonding four adjacent 20 MHz channels. This provides even higher throughput, making it suitable for applications that require high data rates, such as HD video streaming and online gaming. However, 80 MHz channels are more susceptible to interference and are typically used in the 5 GHz and 6 GHz bands where more spectrum is available.
160 MHz Channels
160 MHz channels offer the highest data rates by bonding eight adjacent 20 MHz channels. This channel width is ideal for applications that demand extremely high throughput, such as virtual reality (VR) and large file transfers. However, 160 MHz channels are highly susceptible to interference and are only practical in the 5 GHz and 6 GHz bands with sufficient spectrum availability.
Channel Width Selection
The choice of channel width depends on the specific requirements of the network and the environment. In dense environments with many Wi-Fi networks, narrower channels (20 MHz or 40 MHz) are preferred to minimize interference. In less congested environments, wider channels (80 MHz or 160 MHz) can be used to achieve higher data rates.
Impact on Performance
Wider channels can significantly improve Wi-Fi performance by increasing the data rate and reducing latency. However, they also require more spectrum and are more vulnerable to interference. It is essential to balance the need for higher throughput with the potential for increased interference when selecting the appropriate channel width for a Wi-Fi network.
In summary, Wi-Fi channel width plays a crucial role in determining the performance and reliability of a wireless network. Understanding the trade-offs between different channel widths can help optimize the network for specific applications and environments.
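The bonding described above is a simple multiple of 20 MHz sub-channels. As an illustrative sketch (the helper name is ours), this lists the 20 MHz member channels of a bonded 5 GHz channel, given its lowest member and the bonded width (5 GHz channel numbers step by 4 per 20 MHz):

```javascript
// List the 20 MHz member channels of a bonded 5 GHz channel,
// given its lowest 20 MHz channel number and the bonded width in MHz.
function bondedMembers(lowestCh, widthMHz) {
  const n = widthMHz / 20; // number of bonded 20 MHz sub-channels
  if (![1, 2, 4, 8].includes(n)) throw new RangeError(`unsupported width: ${widthMHz} MHz`);
  return Array.from({ length: n }, (_, i) => lowestCh + 4 * i);
}

console.log(bondedMembers(36, 80)); // [ 36, 40, 44, 48 ]
```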
Identifying Channel Width from Beacon Frames
To identify the channel width from Wi-Fi beacon frames, you need to analyze the information elements (IEs) within the beacon frame. Beacon frames are periodically transmitted by access points (APs) to announce the presence of a Wi-Fi network. These frames contain various IEs that provide information about the network, including the channel width.
Steps to Identify Channel Width
-
Capture Beacon Frames: Use a Wi-Fi packet capture tool (e.g., Wireshark) to capture beacon frames from the Wi-Fi network. Ensure that your capture device supports the frequency bands and channel widths used by the network.
-
Locate the HT Capabilities IE: In the captured beacon frame, locate the "HT Capabilities" information element. This IE is present in 802.11n and later standards and provides information about the supported channel widths.
-
Check Supported Channel Widths: Within the HT Capabilities IE, look for the "Supported Channel Width Set" field. This field indicates whether the AP supports 20 MHz, 40 MHz, or both channel widths. The field is typically represented as:
- 0: 20 MHz only
- 1: 20 MHz and 40 MHz
-
Locate the VHT Capabilities IE: For 802.11ac networks, locate the "VHT Capabilities" information element. This IE provides information about the supported channel widths for very high throughput (VHT) networks.
-
Check VHT Supported Channel Widths: Within the VHT Capabilities IE, look for the "Supported Channel Width Set" field. This field indicates whether the AP supports 20 MHz, 40 MHz, 80 MHz, or 160 MHz channel widths. The field is typically represented as:
- 0: up to 80 MHz (160 MHz and 80+80 MHz not supported)
- 1: 160 MHz supported
- 2: 160 MHz and 80+80 MHz supported
-
Analyze HE Capabilities IE: For 802.11ax (Wi-Fi 6) networks, locate the "HE Capabilities" information element. This IE provides information about the supported channel widths for high-efficiency (HE) networks.
-
Check HE Supported Channel Widths: Within the HE Capabilities IE, look for the "Supported Channel Width Set" field. This field indicates whether the AP supports 20 MHz, 40 MHz, 80 MHz, 160 MHz, or 80+80 MHz channel widths.
Example
Here is an example of how to identify the channel width from a beacon frame using Wireshark:
- Open Wireshark and start capturing packets on the desired Wi-Fi interface.
- Filter the captured packets to display only beacon frames using the filter: wlan.fc.type_subtype == 0x08.
- Select a beacon frame from the list and expand the "IEEE 802.11 wireless LAN management frame" section.
- Locate the "HT Capabilities" IE and check the "Supported Channel Width Set" field.
- If applicable, locate the "VHT Capabilities" IE and check the "Supported Channel Width Set" field.
- If applicable, locate the "HE Capabilities" IE and check the "Supported Channel Width Set" field.
By following these steps, you can determine the channel width supported by the Wi-Fi network from the beacon frames.
Tools
- Wireshark: A popular network protocol analyzer that can capture and analyze Wi-Fi packets, including beacon frames.
- Aircrack-ng: A suite of tools for capturing and analyzing Wi-Fi packets, including airodump-ng for capturing beacon frames.
Understanding the channel width from beacon frames can help optimize Wi-Fi network performance and troubleshoot connectivity issues. By analyzing the beacon frames, you can gain insights into the network's capabilities and configuration.
Types of Frames in Wi-Fi
Wi-Fi communication relies on the exchange of various types of frames between devices. These frames are categorized into three main types: management frames, control frames, and data frames. Each type of frame serves a specific purpose in the operation and maintenance of the Wi-Fi network.
-
Management Frames: Management frames are used to establish and maintain connections between devices in a Wi-Fi network. They facilitate the discovery, authentication, and association processes. Common types of management frames include:
- Beacon Frames: Broadcasted periodically by access points (APs) to announce the presence and capabilities of the network.
- Probe Request Frames: Sent by clients to discover available networks.
- Probe Response Frames: Sent by APs in response to probe requests, providing information about the network.
- Authentication Frames: Used to initiate the authentication process between a client and an AP.
- Deauthentication Frames: Used to terminate an existing authentication.
- Association Request Frames: Sent by clients to request association with an AP.
- Association Response Frames: Sent by APs in response to association requests, indicating acceptance or rejection.
- Disassociation Frames: Used to terminate an existing association.
-
Control Frames: Control frames assist in the delivery of data frames and help manage access to the wireless medium. They ensure that data frames are transmitted efficiently and without collisions. Common types of control frames include:
- Request to Send (RTS) Frames: Used to request permission to send data, helping to avoid collisions in a busy network.
- Clear to Send (CTS) Frames: Sent in response to RTS frames, granting permission to send data.
- Acknowledgment (ACK) Frames: Sent to confirm the successful receipt of data frames.
- Power Save Poll (PS-Poll) Frames: Used by clients in power-saving mode to request buffered data from the AP.
-
Data Frames: Data frames carry the actual data payload between devices in a Wi-Fi network. They are used for the transmission of user data, such as web pages, emails, and file transfers. Data frames can also include additional information, such as quality of service (QoS) parameters, to prioritize certain types of traffic. Common types of data frames include:
- Data Frames: Carry user data between devices.
- Null Data Frames: Used for power management, indicating that a device is awake or entering sleep mode.
- QoS Data Frames: Include QoS parameters to prioritize certain types of traffic, such as voice or video.
Understanding the different types of frames in Wi-Fi is essential for analyzing and troubleshooting wireless networks. Each frame type plays a crucial role in the overall operation and performance of the network, ensuring reliable and efficient communication between devices.
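The three frame categories above are distinguished by two bits in the 802.11 Frame Control field: the first FC byte packs the protocol version in bits 0-1, the type in bits 2-3 (0 = management, 1 = control, 2 = data), and the subtype in bits 4-7 (e.g. beacons are management subtype 8, which is why the Wireshark filter uses type_subtype 0x08). A sketch of decoding that byte:

```javascript
// Decode the first byte of the 802.11 Frame Control field.
// Bits 0-1: protocol version, bits 2-3: type, bits 4-7: subtype.
function frameKind(fcByte0) {
  const type = (fcByte0 >> 2) & 0x3;
  const subtype = (fcByte0 >> 4) & 0xf;
  const names = { 0: 'management', 1: 'control', 2: 'data' };
  return { type: names[type] ?? 'reserved', subtype };
}

console.log(frameKind(0x80)); // { type: 'management', subtype: 8 } -> a beacon frame
```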
Wifi Standards
802.11
802.11a
- Released: 1999
- Frequency: 5 GHz
- Maximum Speed: 54 Mbps
- Notes: First standard to use OFDM (Orthogonal Frequency Division Multiplexing).
802.11b
- Released: 1999
- Frequency: 2.4 GHz
- Maximum Speed: 11 Mbps
- Notes: Uses DSSS (Direct Sequence Spread Spectrum) modulation.
802.11g
- Released: 2003
- Frequency: 2.4 GHz
- Maximum Speed: 54 Mbps
- Notes: Backward compatible with 802.11b, uses OFDM.
802.11n
- Released: 2009
- Frequency: 2.4 GHz and 5 GHz
- Maximum Speed: 600 Mbps
- Notes: Introduced MIMO (Multiple Input Multiple Output) technology.
802.11ac
- Released: 2013
- Frequency: 5 GHz
- Maximum Speed: 6.9 Gbps (theoretical; about 1.3 Gbps for common Wave 1 devices)
- Notes: Also known as Wi-Fi 5; uses wider channels (80 or 160 MHz), 256-QAM, and more spatial streams.
802.11ax
- Released: 2019
- Frequency: 2.4 GHz and 5 GHz (extended to 6 GHz by Wi-Fi 6E)
- Maximum Speed: 9.6 Gbps
- Notes: Also known as Wi-Fi 6, introduces OFDMA (Orthogonal Frequency Division Multiple Access) and improved efficiency in dense environments.
802.11be
- Released: 2024
- Frequency: 2.4 GHz, 5 GHz, and 6 GHz
- Maximum Speed: 46 Gbps (theoretical)
- Notes: Also known as Wi-Fi 7, introduces EHT (Extremely High Throughput) features such as 320 MHz channels, 4096-QAM, and Multi-Link Operation (MLO).
Wi-Fi Security
Wi-Fi security is crucial for protecting wireless networks from unauthorized access and ensuring the confidentiality and integrity of data transmitted over the air. There are several wireless security protocols and mechanisms that have been developed over the years to enhance the security of Wi-Fi networks. Here are some of the most common wireless security protocols:
WEP (Wired Equivalent Privacy)
- Introduced: 1997
- Encryption: RC4 stream cipher
- Key Length: 40-bit or 104-bit
- Notes: WEP was the first security protocol for Wi-Fi networks, designed to provide a level of security comparable to that of a wired network. However, it has significant vulnerabilities and is considered insecure. It is no longer recommended for use.
WPA (Wi-Fi Protected Access)
- Introduced: 2003
- Encryption: TKIP (Temporal Key Integrity Protocol)
- Key Length: 128-bit
- Notes: WPA was introduced as an interim solution to address the weaknesses of WEP. It uses TKIP to improve encryption and includes mechanisms for key management and integrity checking. While more secure than WEP, WPA has been largely replaced by WPA2.
WPA2 (Wi-Fi Protected Access II)
- Introduced: 2004
- Encryption: AES (Advanced Encryption Standard)
- Key Length: 128-bit
- Notes: WPA2 is the most widely used Wi-Fi security protocol today. It uses AES for encryption, which is considered highly secure. WPA2 also includes support for CCMP (Counter Mode with Cipher Block Chaining Message Authentication Code Protocol) for data integrity and confidentiality. It is recommended for all modern Wi-Fi networks.
Technical Details
WPA2 operates in two modes: Personal (WPA2-PSK) and Enterprise (WPA2-Enterprise).
- WPA2-Personal (Pre-Shared Key - PSK):
- Uses a pre-shared key for authentication.
- Suitable for home and small office networks.
- The pre-shared key is used to derive the Pairwise Transient Key (PTK), which is used for encrypting data between the client and the access point.
- WPA2-Enterprise:
- Uses 802.1X authentication with an external RADIUS server.
- Suitable for enterprise and large networks.
- Provides individual authentication credentials for each user.
- Supports various Extensible Authentication Protocol (EAP) methods, such as EAP-TLS, EAP-TTLS, and PEAP.
Key Management
WPA2 uses a robust key management framework to ensure secure communication:
- Pairwise Master Key (PMK): Derived from the pre-shared key (PSK) in WPA2-Personal or obtained through 802.1X authentication in WPA2-Enterprise.
- Pairwise Transient Key (PTK): Derived from the PMK, the client MAC address, the access point MAC address, and nonces exchanged during the 4-way handshake. The PTK is used to encrypt unicast traffic between the client and the access point.
- Group Temporal Key (GTK): Used to encrypt broadcast and multicast traffic. The GTK is generated by the access point and distributed to clients during the 4-way handshake.
4-Way Handshake
The 4-way handshake is a crucial process in WPA2 that ensures the secure exchange of encryption keys between the client and the access point:
- Message 1: The access point sends an ANonce (a random number) to the client.
- Message 2: The client generates an SNonce (another random number) and uses it, along with the ANonce, to derive the PTK. The client then sends the SNonce to the access point.
- Message 3: The access point uses the SNonce and ANonce to derive the PTK. It then sends the GTK (encrypted with the PTK) and a message integrity code (MIC) to the client.
- Message 4: The client sends an acknowledgment to the access point, indicating that it has successfully installed the PTK and GTK.
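The key derivation inside the handshake above can be sketched in Python. Both sides feed the PMK, the two MAC addresses, and the two nonces into the 802.11i PRF (built on HMAC-SHA1) to produce the 512-bit PTK; the PMK, addresses, and nonces below are dummy values for illustration only.

```python
import hashlib
import hmac

def prf_512(pmk: bytes, label: bytes, data: bytes) -> bytes:
    # IEEE 802.11i PRF: iterate HMAC-SHA1(pmk, label | 0x00 | data | i)
    # until at least 512 bits (64 bytes) of output are produced.
    out = b""
    i = 0
    while len(out) < 64:
        out += hmac.new(pmk, label + b"\x00" + data + bytes([i]),
                        hashlib.sha1).digest()
        i += 1
    return out[:64]

def derive_ptk(pmk, aa, spa, anonce, snonce):
    # Addresses and nonces are sorted lexicographically, so the client
    # and the AP derive the identical PTK regardless of argument order.
    data = (min(aa, spa) + max(aa, spa) +
            min(anonce, snonce) + max(anonce, snonce))
    return prf_512(pmk, b"Pairwise key expansion", data)

pmk = bytes(32)   # dummy PMK; the real one comes from the PSK or 802.1X
aa = bytes.fromhex("020000000001")    # AP MAC (dummy)
spa = bytes.fromhex("020000000002")   # client MAC (dummy)
anonce, snonce = bytes(32), bytes([1] * 32)
ptk = derive_ptk(pmk, aa, spa, anonce, snonce)  # 64 bytes: KCK | KEK | TK
```

The 64-byte result is split into the key confirmation key (KCK, used for the MIC), the key encryption key (KEK, used to wrap the GTK in Message 3), and the temporal key (TK) that encrypts unicast data.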
Authentication and Key Management (AKM) Suites
WPA2 supports various AKM suites to provide flexibility in authentication methods:
- PSK (Pre-Shared Key): Used in WPA2-Personal for simple passphrase-based authentication.
- 802.1X: Used in WPA2-Enterprise for authentication with a RADIUS server.
- FT (Fast Transition): Also known as 802.11r, used to enable fast roaming between access points without re-authentication.
- SAE (Simultaneous Authentication of Equals): Introduced in WPA3 but can be used in WPA2 for enhanced security.
Frames in WPA2
WPA2 uses several types of frames to manage security and encryption:
- Authentication Frames: Used to initiate the authentication process between the client and the access point.
- Association Frames: Used to establish a connection between the client and the access point.
- EAPOL (Extensible Authentication Protocol over LAN) Frames: Used during the 4-way handshake to exchange nonces and encryption keys.
- Data Frames: Encrypted using the PTK for unicast traffic and the GTK for broadcast/multicast traffic.
By understanding the technical details and mechanisms of WPA2, users and network administrators can ensure robust security for their Wi-Fi networks, protecting against unauthorized access and ensuring the confidentiality and integrity of their data.
WPA3 (Wi-Fi Protected Access III)
- Introduced: 2018
- Encryption: AES (CCMP-128 by default; GCMP-256 in the 192-bit Enterprise mode)
- Key Length: 128-bit or 192-bit
- Notes: WPA3 is the latest Wi-Fi security protocol, designed to provide enhanced security features over WPA2. It includes improvements such as Simultaneous Authentication of Equals (SAE) for stronger password-based authentication, forward secrecy to protect data even if a key is compromised, and improved protection against brute-force attacks. WPA3 is recommended for new Wi-Fi networks and devices.
Key Management
WPA3 introduces a more robust key management framework to enhance security:
- Simultaneous Authentication of Equals (SAE): A secure key establishment protocol that replaces the pre-shared key (PSK) method used in WPA2-Personal. SAE provides protection against offline dictionary attacks and ensures forward secrecy.
- Pairwise Master Key (PMK): Derived from the SAE process in WPA3-Personal or obtained through 802.1X authentication in WPA3-Enterprise.
- Pairwise Transient Key (PTK): Derived from the PMK, the client MAC address, the access point MAC address, and nonces exchanged during the 4-way handshake. The PTK is used to encrypt unicast traffic between the client and the access point.
- Group Temporal Key (GTK): Used to encrypt broadcast and multicast traffic. The GTK is generated by the access point and distributed to clients during the 4-way handshake.
4-Way Handshake
The 4-way handshake in WPA3 is similar to WPA2 but includes enhancements for improved security:
- Message 1: The access point sends an ANonce (a random number) to the client.
- Message 2: The client generates an SNonce (another random number) and uses it, along with the ANonce, to derive the PTK. The client then sends the SNonce to the access point.
- Message 3: The access point uses the SNonce and ANonce to derive the PTK. It then sends the GTK (encrypted with the PTK) and a message integrity code (MIC) to the client.
- Message 4: The client sends an acknowledgment to the access point, indicating that it has successfully installed the PTK and GTK.
Authentication and Key Management (AKM) Suites
WPA3 supports various AKM suites to provide flexibility in authentication methods:
- SAE (Simultaneous Authentication of Equals): Used in WPA3-Personal for secure password-based authentication.
- 802.1X: Used in WPA3-Enterprise for authentication with a RADIUS server.
- Suite B: A set of cryptographic algorithms approved by the National Security Agency (NSA) for use in high-security environments. Suite B includes support for 192-bit encryption keys and elliptic curve cryptography (ECC).
Frames in WPA3
WPA3 uses several types of frames to manage security and encryption:
- Authentication Frames: Used to initiate the authentication process between the client and the access point.
- Association Frames: Used to establish a connection between the client and the access point.
- EAPOL (Extensible Authentication Protocol over LAN) Frames: Used during the 4-way handshake to exchange nonces and encryption keys.
- Data Frames: Encrypted using the PTK for unicast traffic and the GTK for broadcast/multicast traffic.
By understanding the technical details and mechanisms of WPA3, users and network administrators can ensure robust security for their Wi-Fi networks, protecting against unauthorized access and ensuring the confidentiality and integrity of their data.
802.1X (Port-Based Network Access Control)
- Introduced: 2001
- Authentication: EAP (Extensible Authentication Protocol)
- Notes: 802.1X is a network access control protocol that provides an authentication framework for wired and wireless networks. It is commonly used in enterprise environments to authenticate users and devices before granting access to the network. 802.1X can be used in conjunction with WPA2 and WPA3 for enhanced security.
Technical Details
802.1X operates at the data link layer (Layer 2) of the OSI model and uses the Extensible Authentication Protocol (EAP) to facilitate authentication. The protocol involves three main components:
- Supplicant: The device (e.g., a laptop or smartphone) that requests access to the network.
- Authenticator: The network device (e.g., a switch or wireless access point) that controls access to the network.
- Authentication Server: The server (e.g., a RADIUS server) that validates the credentials of the supplicant.
Authentication Process
The 802.1X authentication process involves the following steps:
- Initialization: The supplicant connects to the network and the authenticator detects the connection.
- EAPOL-Start: The supplicant sends an EAPOL-Start frame to the authenticator to initiate the authentication process.
- EAP-Request/Identity: The authenticator responds with an EAP-Request/Identity frame, asking the supplicant for its identity.
- EAP-Response/Identity: The supplicant replies with an EAP-Response/Identity frame, providing its identity to the authenticator.
- RADIUS Access-Request: The authenticator forwards the identity information to the authentication server in a RADIUS Access-Request message.
- RADIUS Access-Challenge: The authentication server may respond with a RADIUS Access-Challenge message, requesting additional information (e.g., a password or token).
- EAP-Request/Challenge: The authenticator forwards the challenge to the supplicant in an EAP-Request/Challenge frame.
- EAP-Response/Challenge: The supplicant responds with the requested information in an EAP-Response/Challenge frame.
- RADIUS Access-Accept: If the authentication server successfully validates the credentials, it sends a RADIUS Access-Accept message to the authenticator.
- EAP-Success: The authenticator informs the supplicant of successful authentication with an EAP-Success frame.
- Port Authorization: The authenticator grants access to the network by opening the port for the supplicant.
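The steps above form a fixed ordering of messages. The sketch below (illustrative only, not a real EAP state machine) records that expected order and checks whether an observed trace of messages is consistent with it, which is roughly what you do by hand when reading an 802.1X capture:

```python
# Expected 802.1X message order from the steps above; the challenge
# round is optional in some EAP methods, so traces may skip entries.
EXPECTED = ["EAPOL-Start", "EAP-Request/Identity", "EAP-Response/Identity",
            "RADIUS Access-Request", "RADIUS Access-Challenge",
            "EAP-Request/Challenge", "EAP-Response/Challenge",
            "RADIUS Access-Accept", "EAP-Success"]

def in_order(trace):
    """True if every message in `trace` appears in EXPECTED order
    (i.e. the trace is a subsequence of the expected exchange)."""
    it = iter(EXPECTED)
    return all(msg in it for msg in trace)

print(in_order(["EAPOL-Start", "EAP-Response/Identity", "EAP-Success"]))
```

A trace that shows `EAP-Success` before `EAP-Response/Identity`, for example, would fail this check and point to a malformed or replayed exchange.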
Frames in 802.1X
802.1X uses several types of frames to manage the authentication process:
- EAPOL (Extensible Authentication Protocol over LAN) Frames: Used for communication between the supplicant and the authenticator.
- EAPOL-Start: Initiates the authentication process.
- EAPOL-Logoff: Terminates the authentication session.
- EAPOL-Key: Used for key management in WPA/WPA2/WPA3.
- EAPOL-Packet: Carries EAP messages between the supplicant and the authenticator.
- EAP (Extensible Authentication Protocol) Frames: Used for communication between the supplicant and the authentication server.
- EAP-Request: Sent by the authenticator to request information from the supplicant.
- EAP-Response: Sent by the supplicant in response to an EAP-Request.
- EAP-Success: Indicates successful authentication.
- EAP-Failure: Indicates failed authentication.
Authentication and Key Management (AKM) Suites
802.1X supports various AKM suites to provide flexibility in authentication methods:
- EAP-TLS (Transport Layer Security): Uses client and server certificates for mutual authentication.
- EAP-TTLS (Tunneled Transport Layer Security): Establishes a secure tunnel using server certificates, then uses another authentication method (e.g., PAP, CHAP) within the tunnel.
- PEAP (Protected Extensible Authentication Protocol): Establishes a TLS tunnel using a server certificate, then runs an inner EAP method (commonly EAP-MSCHAPv2) inside the tunnel.
- EAP-MSCHAPv2 (Microsoft Challenge Handshake Authentication Protocol version 2): Uses a password-based authentication mechanism.
- EAP-SIM (Subscriber Identity Module): Uses the SIM card in mobile devices for authentication.
By understanding the technical details and mechanisms of 802.1X, users and network administrators can ensure robust security for their wired and wireless networks, protecting against unauthorized access and ensuring the confidentiality and integrity of their data.
WPS (Wi-Fi Protected Setup)
- Introduced: 2007
- Notes: WPS is a network security standard designed to simplify the process of connecting devices to a Wi-Fi network. It allows users to connect to a network by pressing a button on the router or entering a PIN. However, WPS has known vulnerabilities and can be exploited by attackers to gain unauthorized access to the network. It is recommended to disable WPS if security is a concern.
MAC Address Filtering
- Notes: MAC address filtering is a security measure that allows only devices with specific MAC addresses to connect to the Wi-Fi network. While it can provide an additional layer of security, it is not foolproof, as MAC addresses can be spoofed by attackers. It should be used in conjunction with other security measures.
Guest Networks
- Notes: Many modern routers support the creation of guest networks, which provide a separate Wi-Fi network for visitors. Guest networks can be isolated from the main network, preventing guests from accessing sensitive resources. This is a useful feature for enhancing security in both home and business environments.
By understanding and implementing these Wi-Fi security protocols and mechanisms, users and network administrators can protect their wireless networks from unauthorized access and ensure the confidentiality and integrity of their data.
Wi-Fi Scanning
Wi-Fi scanning is the process of identifying available wireless networks within range of a Wi-Fi-enabled device. This process is essential for connecting to Wi-Fi networks, troubleshooting connectivity issues, and optimizing network performance. Wi-Fi scanning can be performed using various tools and techniques, and it typically involves the following steps:
- Initiate Scan: The Wi-Fi-enabled device sends out probe request frames to discover available networks. These frames are broadcast on different channels to ensure that all nearby networks are detected.
- Receive Probe Responses: Access points (APs) within range respond to the probe request frames with probe response frames. These frames contain information about the network, such as the Service Set Identifier (SSID), supported data rates, security protocols, and other capabilities.
- Analyze Beacon Frames: In addition to probe responses, the device can also listen for beacon frames that APs broadcast periodically. Beacon frames contain similar information to probe responses and help the device identify available networks.
- Compile Network List: The device compiles a list of available networks based on the received probe responses and beacon frames. This list includes details such as the SSID, signal strength (RSSI), channel, and security type of each network.
- Select Network: The user or device selects a network from the list to connect to. The selection can be based on various factors, such as signal strength, network name, or security requirements.
Tools for Wi-Fi Scanning
Several tools and utilities can be used for Wi-Fi scanning, including:
- Wireshark: A network protocol analyzer that can capture and analyze Wi-Fi packets, including probe requests, probe responses, and beacon frames.
- NetSpot: A Wi-Fi survey and analysis tool that provides detailed information about available networks, including signal strength, channel usage, and security settings.
- inSSIDer: A Wi-Fi scanner that displays information about nearby networks, such as SSID, signal strength, channel, and security type.
- Acrylic Wi-Fi: A Wi-Fi scanner and analyzer that provides real-time information about available networks, including signal strength, channel usage, and network performance metrics.
Importance of Wi-Fi Scanning
Wi-Fi scanning is crucial for several reasons:
- Network Discovery: It allows users to discover available networks and choose the best one to connect to.
- Troubleshooting: It helps identify connectivity issues, such as weak signals, interference, or misconfigured settings.
- Optimization: It provides insights into network performance and helps optimize the configuration, such as selecting the best channel to minimize interference.
- Security: It helps identify unauthorized or rogue access points that may pose a security threat to the network.
By understanding and utilizing Wi-Fi scanning techniques, users and network administrators can ensure reliable and efficient wireless connectivity.
Roaming
Overview
WiFi roaming is the process by which a client device (such as a smartphone or laptop) seamlessly transitions from one access point (AP) to another within the same network without losing connectivity. This is essential for maintaining uninterrupted service in environments with multiple APs, such as offices, campuses, and large homes.
The Roaming Process
When a client device roams, several steps occur:
- Discovery: The client scans for available APs and measures their signal strength (RSSI - Received Signal Strength Indicator).
- Decision: Based on signal strength, network load, and other factors, the client decides to roam to a different AP.
- Authentication: The client authenticates with the new AP.
- Reassociation: The client reassociates with the new AP, completing the handoff.
- Key Exchange: Security keys are exchanged to establish a secure connection.
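The discovery and decision steps above typically combine a trigger threshold with a hysteresis margin so the client does not ping-pong between two APs of similar strength. A minimal sketch, with assumed threshold values (real clients use vendor-specific logic):

```python
# Illustrative roam-decision thresholds; actual values vary by vendor.
ROAM_TRIGGER_DBM = -70   # start considering a roam below this RSSI
HYSTERESIS_DB = 8        # candidate must beat the current AP by this margin

def should_roam(current_rssi: int, candidate_rssi: int) -> bool:
    # Roam only when the current link is weak AND the candidate is
    # meaningfully better, preventing oscillation between nearby APs.
    return (current_rssi < ROAM_TRIGGER_DBM and
            candidate_rssi >= current_rssi + HYSTERESIS_DB)

print(should_roam(-75, -60))  # weak link, much stronger candidate
```

Without the hysteresis margin, a client sitting midway between two APs would reassociate constantly as their RSSI readings jitter.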
Legacy Roaming Challenges
Traditional roaming (without 802.11r/k/v/w) has several challenges:
- High Latency: Full 802.1X authentication can take 50-100ms or more, disrupting real-time applications like VoIP.
- Poor Decision Making: Clients lack information about neighboring APs and may make suboptimal roaming decisions.
- Security Vulnerabilities: Management frames are unprotected, allowing deauthentication attacks.
- Inefficient Scanning: Clients must actively scan all channels to discover APs, wasting time and battery.
The 802.11r, 802.11k, 802.11v, and 802.11w standards address these challenges by introducing fast transitions, radio resource management, network management, and security enhancements.
Basic Roaming Flow Diagram
sequenceDiagram
participant Client
participant AP1 as Current AP
participant AP2 as Target AP
participant DS as Distribution System
Note over Client,AP1: Client connected to AP1
Client->>Client: Monitors signal strength (RSSI)
Client->>Client: Signal from AP1 weakening
Note over Client: Discovery Phase
Client->>Client: Scan for nearby APs
Client->>AP2: Probe Request
AP2->>Client: Probe Response
Note over Client: Decision Phase
Client->>Client: Evaluate AP options<br/>(signal, load, capabilities)
Client->>Client: Select AP2 as target
Note over Client,AP2: Reassociation Phase
Client->>AP2: Authentication Request
AP2->>Client: Authentication Response
Client->>AP2: Reassociation Request
AP2->>DS: Notify about client reassociation
DS->>AP1: Forward reassociation notice
AP1->>DS: Release client context
AP2->>Client: Reassociation Response
Note over Client,AP2: Client now connected to AP2
Roaming Standards Comparison
| Standard | Purpose | Key Benefit | Typical Latency |
|---|---|---|---|
| Legacy | Basic roaming | Simple implementation | 50-100ms+ |
| 802.11r | Fast BSS Transition | Reduced authentication time | <10ms |
| 802.11k | Radio Resource Management | Better AP selection | N/A (decision aid) |
| 802.11v | Network Management | Network-assisted roaming | Improved efficiency |
| 802.11w | Protected Management Frames | Security against attacks | N/A (security) |
802.11r
- Also known as Fast BSS Transition (FT).
- Released: 2008.
- Purpose: Improves the speed of the handoff process between access points.
- Notes: Reduces the time required for re-authentication when a device moves from one AP to another.
Technical Details of 802.11r
802.11r, also known as Fast BSS Transition (FT), is a standard that aims to improve the handoff process between access points (APs) in a wireless network. This is particularly important for applications that require seamless connectivity, such as VoIP (Voice over IP) and real-time video streaming. Here are some key technical details:
- Key Caching:
- 802.11r introduces the concept of key caching, which allows a client device to reuse the Pairwise Master Key (PMK) from a previous connection when roaming to a new AP. This reduces the time required for re-authentication.
- Fast Transition (FT) Protocol:
- The FT protocol defines two methods for fast transitions: over-the-air and over-the-DS (Distribution System).
- Over-the-Air: The client communicates directly with the target AP to perform the handoff.
- Over-the-DS: The client communicates with the target AP through the current AP, using the wired network (DS) as an intermediary.
- Reduced Latency:
- By minimizing the time required for re-authentication and key exchange, 802.11r significantly reduces the latency associated with roaming. This is crucial for maintaining the quality of real-time applications.
- FT Initial Mobility Domain Association:
- When a client first associates with an AP in an 802.11r-enabled network, it performs an FT Initial Mobility Domain Association. This process establishes the necessary security context and prepares the client for fast transitions within the mobility domain.
- Mobility Domain Information Element (MDIE):
- The MDIE is included in the beacon frames and probe responses of 802.11r-enabled APs. It provides information about the mobility domain, allowing client devices to identify and connect to APs that support fast transitions.
- Fast BSS Transition Information Element (FTIE):
- The FTIE is used during the authentication and reassociation processes to carry the necessary cryptographic information for fast transitions. It ensures that the security context is properly established and maintained during the handoff.
- Compatibility:
- 802.11r is designed to be backward compatible with non-802.11r devices. APs can support both 802.11r and non-802.11r clients simultaneously, ensuring a smooth transition for devices that do not support the standard.
- Key Hierarchy in 802.11r:
- The key hierarchy in 802.11r builds upon the existing 802.11i security framework:
- PMK (Pairwise Master Key): Derived from the initial 802.1X authentication, cached for reuse during fast transitions
- PMK-R0: First-level derivation from PMK, includes the mobility domain identifier (MDID)
- PMK-R1: Second-level derivation, specific to the target AP's R1 Key Holder
- PTK (Pairwise Transient Key): Final session key derived from PMK-R1 during reassociation
- This hierarchical approach enables APs to derive session keys without contacting the authentication server
- Mobility Domain:
- A Mobility Domain (MD) is a group of APs that share the same security context and allow fast transitions
- All APs in the same MD advertise the same Mobility Domain Identifier (MDID) in their beacons
- Clients that associate with an AP in an MD can roam to any other AP in the same MD using fast transitions
- The MD eliminates the need for full re-authentication when moving between APs
- R0 and R1 Key Holders:
- R0 Key Holder (R0KH): Typically the RADIUS/authentication server or a centralized controller that holds PMK-R0
- R1 Key Holder (R1KH): Each AP in the mobility domain, holds PMK-R1 keys for clients
- During FT, the target AP's R1KH requests the PMK-R1 from the R0KH or retrieves it from the current AP
By implementing these technical features, 802.11r enhances the efficiency and reliability of the roaming process, providing a better user experience in environments with multiple access points.
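The PMK-R0/PMK-R1 hierarchy described above can be sketched as chained key derivations. Note the `kdf` below is a simplified HMAC-SHA256 stand-in, not the exact 802.11r KDF (which uses precisely specified labels and bit-length encodings); the key material, MDID, and key-holder names are dummy values:

```python
import hashlib
import hmac

def kdf(key: bytes, label: bytes, context: bytes) -> bytes:
    # Simplified stand-in for the 802.11r KDF-SHA-256.
    return hmac.new(key, label + context, hashlib.sha256).digest()

pmk = bytes(32)                       # from the initial 802.1X auth (dummy)
mdid = b"\x12\x34"                    # Mobility Domain Identifier (dummy)
r0kh_id = b"r0kh.example"             # R0 Key Holder identity (dummy)

# First-level key, held at the R0 Key Holder (controller/auth server).
pmk_r0 = kdf(pmk, b"FT-R0", mdid + r0kh_id)

# One PMK-R1 per AP in the mobility domain; each AP can derive a fresh
# PTK from its PMK-R1 without contacting the authentication server.
pmk_r1_ap1 = kdf(pmk_r0, b"FT-R1", b"ap1-r1kh")
pmk_r1_ap2 = kdf(pmk_r0, b"FT-R1", b"ap2-r1kh")
```

The point of the hierarchy is visible here: roaming to AP2 only requires AP2's PMK-R1, which is pre-derivable from PMK-R0, so no 802.1X round-trip happens at handoff time.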
802.11r Key Hierarchy
graph TB
MSK[MSK - Master Session Key<br/>from 802.1X/EAP Auth]
MSK --> PMK[PMK - Pairwise Master Key<br/>Derived via PRF]
PMK --> PMKR0[PMK-R0<br/>Includes: SSID + MDID + R0KHID<br/>Stored at R0 Key Holder]
PMKR0 --> PMKR1_AP1[PMK-R1 for AP1<br/>Includes: R1KHID AP1<br/>Stored at AP1]
PMKR0 --> PMKR1_AP2[PMK-R1 for AP2<br/>Includes: R1KHID AP2<br/>Stored at AP2]
PMKR0 --> PMKR1_AP3[PMK-R1 for AP3<br/>Includes: R1KHID AP3<br/>Stored at AP3]
PMKR1_AP1 --> PTK1[PTK for Session with AP1<br/>Ephemeral, per-association]
PMKR1_AP2 --> PTK2[PTK for Session with AP2<br/>Ephemeral, per-association]
PMKR1_AP3 --> PTK3[PTK for Session with AP3<br/>Ephemeral, per-association]
PTK1 --> Traffic1[Encrypted Traffic]
PTK2 --> Traffic2[Encrypted Traffic]
PTK3 --> Traffic3[Encrypted Traffic]
style MSK fill:#FFE6E6
style PMK fill:#FFE6E6
style PMKR0 fill:#FFF9E6
style PMKR1_AP1 fill:#E6F3FF
style PMKR1_AP2 fill:#E6F3FF
style PMKR1_AP3 fill:#E6F3FF
style PTK1 fill:#E6FFE6
style PTK2 fill:#E6FFE6
style PTK3 fill:#E6FFE6
802.11r Fast BSS Transition Flow
sequenceDiagram
participant Client
participant CurrentAP as Current AP (AP1)
participant TargetAP as Target AP (AP2)
participant DS as Distribution System
participant AuthServer as Authentication Server
Note over Client,CurrentAP: Initial Association with FT
Client->>CurrentAP: Association Request (MDIE, RSN)
CurrentAP->>AuthServer: Full 802.1X Authentication
AuthServer->>CurrentAP: PMK (Pairwise Master Key)
CurrentAP->>CurrentAP: Derive PTK, GTK
CurrentAP->>Client: Association Response (MDIE, FTIE)
Note over Client,CurrentAP: Client now part of Mobility Domain
Note over Client: Signal weakening, client scans
Client->>TargetAP: Probe Request
TargetAP->>Client: Probe Response (MDIE, FT-capable)
alt Over-the-Air FT
Note over Client,TargetAP: Fast Transition (Over-the-Air)
Client->>TargetAP: FT Authentication Request (FTIE, MDIE)
TargetAP->>CurrentAP: Request PMK context (via DS)
CurrentAP->>TargetAP: Transfer PMK context
TargetAP->>TargetAP: Derive new PTK using PMK
TargetAP->>Client: FT Authentication Response (FTIE)
Client->>TargetAP: FT Reassociation Request
TargetAP->>Client: FT Reassociation Response
Note over Client,TargetAP: Handoff complete (~10ms)
else Over-the-DS FT
Note over Client,DS: Fast Transition (Over-the-DS)
Client->>CurrentAP: FT Request (to TargetAP)
CurrentAP->>TargetAP: Forward FT Request + PMK
TargetAP->>TargetAP: Derive new PTK
TargetAP->>CurrentAP: FT Response
CurrentAP->>Client: FT Response
Client->>TargetAP: FT Reassociation Request
TargetAP->>Client: FT Reassociation Response
Note over Client,TargetAP: Handoff complete (~10ms)
end
Key Differences: Legacy vs 802.11r
flowchart TD
subgraph Legacy["Legacy Roaming (50-100ms)"]
L1[Scan & Probe] --> L2[Open Authentication]
L2 --> L3[Full 802.1X Auth]
L3 --> L4[4-Way Handshake]
L4 --> L5[Reassociation]
end
subgraph FT["802.11r Fast Transition (<10ms)"]
F1[Scan & Probe] --> F2[FT Authentication<br/>with PMK Cache]
F2 --> F3[FT Reassociation<br/>PTK Derived]
style F3 fill:#90EE90
end
style Legacy fill:#FFB6C6
style FT fill:#B6FFB6
802.11k
- Also known as Radio Resource Management (RRM).
- Released: 2008.
- Purpose: Provides mechanisms for measuring and reporting the radio environment.
- Notes: Helps devices make better roaming decisions by providing information about neighboring APs.
Technical Details of 802.11k
802.11k, also known as Radio Resource Management (RRM), is a standard that provides mechanisms for measuring and reporting the radio environment. This information helps client devices make better roaming decisions by providing data about neighboring access points (APs). Here are some key technical details:
- Neighbor Reports:
- 802.11k enables APs to provide neighbor reports to client devices. These reports contain information about nearby APs, including their signal strength, channel, and supported data rates. This helps clients identify the best AP to roam to.
- Beacon Reports:
- Client devices can request beacon reports from APs. These reports include details about the beacons received from neighboring APs, such as signal strength and channel utilization. This information assists clients in making informed roaming decisions.
- Channel Load Reports:
- APs can provide channel load reports, which indicate the level of traffic on a particular channel. This helps client devices avoid congested channels and select APs operating on less crowded frequencies.
- Noise Histogram Reports:
- Noise histogram reports provide information about the noise levels on different channels. By analyzing these reports, client devices can avoid channels with high levels of interference, improving overall network performance.
- Transmit Stream/Category Measurement Reports:
- These reports provide data on the performance of specific traffic streams or categories. This helps client devices assess the quality of service (QoS) provided by different APs and make better roaming decisions based on their specific needs.
- Location Tracking:
- 802.11k supports location tracking features, allowing APs to track the location of client devices within the network. This information can be used to optimize network performance and improve the accuracy of neighbor reports.
- Link Measurement Reports:
- Link measurement reports provide detailed information about the quality of the wireless link between the client device and the AP. This includes metrics such as signal-to-noise ratio (SNR) and packet error rate (PER), which help clients evaluate the performance of their current connection and potential target APs.
By implementing these technical features, 802.11k enhances the ability of client devices to make informed roaming decisions, leading to improved network performance and a better user experience in environments with multiple access points.
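A client can fold several of the report types above into a single roaming score. The sketch below combines signal strength with channel load; the weights, RSSI normalization, and sample AP data are illustrative assumptions, since real clients weigh 802.11k data in vendor-specific ways:

```python
def score_ap(rssi_dbm: int, channel_busy_pct: int) -> float:
    # Map RSSI from roughly -100..-50 dBm onto 0..100, then blend with
    # channel load (from an 802.11k Channel Load Report). Weights are
    # illustrative only.
    signal = max(0, min(100, 2 * (rssi_dbm + 100)))
    return 0.7 * signal + 0.3 * (100 - channel_busy_pct)

# Hypothetical neighbor-report data: name -> (RSSI dBm, channel busy %).
aps = {"AP2": (-55, 30), "AP3": (-60, 10)}
best = max(aps, key=lambda name: score_ap(*aps[name]))
print(best)
```

Note how the less-loaded AP3 nearly overtakes the stronger AP2; without the channel-load term from 802.11k, the client would only ever see RSSI.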
802.11k Neighbor Discovery and Reporting
sequenceDiagram
participant Client
participant AP1 as Current AP
participant AP2 as Neighbor AP
participant AP3 as Neighbor AP
Note over Client,AP1: Client associated with AP1
rect rgb(230, 240, 255)
Note over Client,AP1: Neighbor Report Request
Client->>AP1: Neighbor Report Request
AP1->>AP1: Generate neighbor list<br/>(AP2, AP3, etc.)
AP1->>Client: Neighbor Report Response<br/>(AP info: BSSID, Channel, PHY)
end
rect rgb(255, 240, 230)
Note over Client: Beacon Report Request (optional)
Client->>Client: Request detailed beacon info
AP1->>Client: Beacon Request (scan AP2, AP3)
Client->>Client: Passive/Active scan
Client->>AP1: Beacon Report<br/>(RSSI, Channel Load, etc.)
end
rect rgb(240, 255, 240)
Note over Client: Radio Measurement Reports
Client->>Client: Analyze reports:<br/>• Signal strength (RSSI)<br/>• Channel utilization<br/>• Noise histogram<br/>• Link quality
Client->>Client: Select best AP for roaming
end
Note over Client,AP2: Client roams to AP2
802.11k Report Types
graph TB
subgraph RRM["802.11k Radio Resource Management"]
NR[Neighbor Report]
BR[Beacon Report]
CLR[Channel Load Report]
NHR[Noise Histogram]
LMR[Link Measurement]
NR -->|Provides| NR1[BSSID, Channel,<br/>Operating Class]
BR -->|Provides| BR1[RSSI, Beacon Interval,<br/>Capability Info]
CLR -->|Provides| CLR1[Channel Busy %,<br/>Medium Utilization]
NHR -->|Provides| NHR1[Interference Levels<br/>per Channel]
LMR -->|Provides| LMR1[SNR, Packet Error Rate,<br/>Transmit Power]
end
NR1 --> Decision[Intelligent<br/>Roaming Decision]
BR1 --> Decision
CLR1 --> Decision
NHR1 --> Decision
LMR1 --> Decision
style Decision fill:#90EE90
style RRM fill:#E6F3FF
802.11v
- Also known as Wireless Network Management.
- Released: 2011.
- Purpose: Enhances network management by providing mechanisms for configuring client devices.
- Notes: Includes features like BSS Transition Management, which helps devices roam more efficiently.
Technical Details of 802.11v
802.11v, also known as Wireless Network Management, is a standard that enhances network management by providing mechanisms for configuring client devices. This standard includes several features that improve the efficiency and performance of wireless networks. Here are some key technical details:
- BSS Transition Management:
- 802.11v provides BSS Transition Management, which helps client devices make better roaming decisions. APs can suggest the best APs for clients to roam to, based on factors like signal strength and load.
- Network Assisted Power Savings:
- This feature allows APs to provide information to client devices about the best times to enter power-saving modes. By coordinating power-saving activities, 802.11v helps extend battery life for client devices.
- Traffic Filtering Service (TFS):
- TFS enables APs to filter traffic for client devices, reducing the amount of unnecessary data that clients need to process. This helps improve the efficiency of the network and reduces power consumption for client devices.
- Wireless Network Management (WNM) Sleep Mode:
- WNM Sleep Mode allows client devices to enter a low-power sleep state while remaining connected to the network. APs can buffer data for sleeping clients and deliver it when they wake up, improving power efficiency without sacrificing connectivity.
- Diagnostic and Reporting:
- 802.11v includes mechanisms for diagnostic and reporting, allowing APs and client devices to exchange information about network performance and issues. This helps network administrators identify and resolve problems more quickly.
- Location Services:
- The standard supports location services, enabling APs to provide location-based information to client devices. This can be used for applications like asset tracking and location-based services.
By implementing these technical features, 802.11v enhances the management and performance of wireless networks, leading to improved efficiency, better power management, and a more reliable user experience in environments with multiple access points.
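A BTM-style decision can be sketched as follows. This is a hypothetical illustration, not a real 802.11v client: the candidate-tuple shape and the RSSI floor are assumptions. The idea it shows is that the client honors the network's preference ranking, but only among candidates its own measurements confirm are usable, and it may decline the request entirely.

```python
# Hypothetical sketch of how a client might evaluate an 802.11v BTM
# candidate list: prefer the AP the network ranks highest, but only if the
# client's own measurement of it clears a minimum RSSI floor.
def choose_btm_candidate(candidates, own_rssi, floor_dbm=-75):
    """candidates: list of (bssid, network_preference 0-255);
    own_rssi: dict of bssid -> RSSI the client measured itself."""
    usable = [c for c in candidates if own_rssi.get(c[0], -100) > floor_dbm]
    if not usable:
        return None  # reject the BTM request, stay on the current AP
    return max(usable, key=lambda c: c[1])[0]

cands = [("ap2", 200), ("ap3", 150)]
measured = {"ap2": -68, "ap3": -62}
print(choose_btm_candidate(cands, measured))  # ap2: highest preference above floor
```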
802.11v BSS Transition Management
sequenceDiagram
participant Client
participant AP1 as Current AP
participant AP2 as Target AP (Preferred)
participant Controller as Network Controller
Note over Client,AP1: Client connected to AP1
rect rgb(255, 230, 230)
Note over AP1,Controller: Network-Initiated Roaming
Controller->>AP1: Detect client should move<br/>(load balancing/signal)
AP1->>Client: BTM Request<br/>(Candidate List: AP2)
Note over Client: Candidate List includes:<br/>• BSSID of target APs<br/>• Operating class & channel<br/>• Preference values
end
rect rgb(230, 255, 230)
Note over Client: Client Decision
Client->>Client: Evaluate BTM candidates<br/>+ own scan results
Client->>Client: Select AP2 (highest preference)
Client->>AP1: BTM Response (Accept)
end
rect rgb(230, 230, 255)
Note over Client,AP2: Roaming Process
Client->>AP2: Authentication & Reassociation
AP2->>Client: Association Response
Note over Client,AP2: Client now on AP2
end
Client->>AP2: BTM Status Report (optional)
802.11v Features Overview
graph LR
subgraph WNM["802.11v Wireless Network Management"]
BTM[BSS Transition<br/>Management]
DMS[Directed Multicast<br/>Service]
FMS[Flexible Multicast<br/>Service]
TFS[Traffic Filtering<br/>Service]
Sleep[WNM Sleep Mode]
end
BTM -->|Benefit| BTM1[Network-assisted<br/>roaming decisions]
DMS -->|Benefit| DMS1[Efficient multicast<br/>delivery]
FMS -->|Benefit| FMS1[Scheduled multicast<br/>for power saving]
TFS -->|Benefit| TFS1[Filter unwanted<br/>traffic at AP]
Sleep -->|Benefit| Sleep1[Deep sleep while<br/>maintaining connection]
BTM1 --> Outcome[Better Performance<br/>& Battery Life]
DMS1 --> Outcome
FMS1 --> Outcome
TFS1 --> Outcome
Sleep1 --> Outcome
style Outcome fill:#90EE90
style WNM fill:#FFE6F0
Client-Initiated vs Network-Initiated Roaming
flowchart TB
subgraph Client["Client-Initiated (Legacy)"]
C1[Client monitors RSSI] --> C2[Signal weakens]
C2 --> C3[Client scans all channels]
C3 --> C4[Client selects AP]
C4 --> C5[Client initiates roam]
end
subgraph Network["Network-Initiated (802.11v)"]
N1[AP/Controller monitors<br/>client conditions] --> N2[AP detects poor signal<br/>or load imbalance]
N2 --> N3[AP sends BTM Request<br/>with candidate list]
N3 --> N4[Client evaluates<br/>suggestions]
N4 --> N5[Client roams to<br/>recommended AP]
end
style Client fill:#FFE6E6
style Network fill:#E6FFE6
802.11w
- Also known as Protected Management Frames (PMF).
- Released: 2009.
- Purpose: Enhances the security of management frames.
- Notes: Protects against certain types of attacks, such as deauthentication and disassociation attacks.
Technical Details of 802.11w
802.11w, also known as Protected Management Frames (PMF), is a standard that enhances the security of management frames in wireless networks. This standard provides mechanisms to protect against certain types of attacks, such as deauthentication and disassociation attacks. Here are some key technical details:
- Management Frame Protection:
- 802.11w provides protection for management frames, which are used for network control and signaling. By securing these frames, the standard helps prevent attackers from disrupting network operations.
- Protected Management Frames (PMF):
- PMF ensures that management frames are both encrypted and authenticated. This prevents unauthorized devices from injecting malicious management frames into the network.
- Robust Security Network (RSN) Associations:
- 802.11w requires the use of RSN associations, which provide a secure method for devices to join the network. This includes the use of cryptographic techniques to protect the integrity and confidentiality of management frames.
- Replay Protection:
- The standard includes mechanisms to protect against replay attacks, where an attacker captures and retransmits management frames to disrupt network operations. By using sequence numbers and timestamps, 802.11w ensures that management frames cannot be reused maliciously.
- Deauthentication and Disassociation Protection:
- 802.11w specifically addresses deauthentication and disassociation attacks, where an attacker forces a device to disconnect from the network. By securing these management frames, the standard helps maintain stable and reliable network connections.
- Cryptographic Protection Mechanisms:
- IGTK (Integrity Group Temporal Key): Used to protect broadcast/multicast management frames
- BIP (Broadcast/Multicast Integrity Protocol): Default integrity algorithm, uses AES-128-CMAC
- MIC (Message Integrity Code): Appended to protected management frames to verify authenticity
- SA Query Mechanism: Allows clients to verify the authenticity of disassociation/deauthentication frames
- Protected Frame Types:
- Disassociation: Protected to prevent forced disconnection attacks
- Deauthentication: Protected to prevent session hijacking
- Robust Management Frames: Action frames related to QoS, spectrum management, and fast BSS transition
- Unprotected Frames: Beacon, Probe Request/Response, and Authentication frames remain unprotected for compatibility
- PMF Modes:
- Optional (PMF=1): Client can connect with or without PMF support
- Required (PMF=2): Client must support PMF to connect (WPA3 requirement)
- Mixed mode allows gradual migration from legacy to protected networks
By implementing these technical features, 802.11w enhances the security of wireless networks, protecting against various types of attacks and ensuring the integrity and reliability of network operations.
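The two checks at the heart of 802.11w — MIC verification and replay protection via the packet number — can be shown conceptually. Note this is not real BIP: actual PMF computes the MIC with AES-128-CMAC under the IGTK over the management frame; HMAC-SHA256 is used here only as a standard-library stand-in so the sketch runs without a crypto dependency, and the frame layout is simplified.

```python
# Conceptual sketch of 802.11w frame protection checks: MIC verification
# plus replay protection via a monotonically increasing packet number (IPN).
import hmac
import hashlib

IGTK = b"\x00" * 16          # group key established during the 4-way handshake
last_ipn = 0                 # last accepted packet number (replay counter)

def protect(frame: bytes, ipn: int) -> bytes:
    """Append an IPN and a truncated MIC to a management frame body."""
    mic = hmac.new(IGTK, ipn.to_bytes(6, "big") + frame, hashlib.sha256).digest()[:8]
    return ipn.to_bytes(6, "big") + frame + mic

def verify(protected: bytes) -> bool:
    global last_ipn
    ipn = int.from_bytes(protected[:6], "big")
    frame, mic = protected[6:-8], protected[-8:]
    expected = hmac.new(IGTK, protected[:6] + frame, hashlib.sha256).digest()[:8]
    if not hmac.compare_digest(mic, expected):
        return False             # forged frame: bad MIC, discard silently
    if ipn <= last_ipn:
        return False             # replayed frame: stale packet number
    last_ipn = ipn
    return True

deauth = protect(b"deauth reason=7", ipn=1)
print(verify(deauth))                                      # True: genuine frame
print(verify(deauth))                                      # False: replay rejected
print(verify(b"\x00" * 6 + b"spoofed deauth" + b"\x00" * 8))  # False: bad MIC
```

An attacker without the IGTK cannot produce a valid MIC, and capturing a genuine frame does not help because its packet number can never be accepted twice.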
802.11w Protected Management Frames
sequenceDiagram
participant Attacker
participant Client
participant AP
Note over Client,AP: Without 802.11w (Vulnerable)
rect rgb(255, 200, 200)
Attacker->>Client: Deauth Frame (Spoofed)
Client->>Client: Disconnect from AP
Note over Client: Connection disrupted!
end
Note over Client,AP: With 802.11w (Protected)
rect rgb(200, 255, 200)
Client->>AP: Initial 4-Way Handshake
AP->>Client: Establish IGTK (Integrity Group Temporal Key)
Attacker->>Client: Deauth Frame (Spoofed)
Client->>Client: Verify frame integrity<br/>using MIC (Message Integrity Code)
Client->>Client: Invalid MIC - Frame rejected
Note over Client,AP: Connection maintained!
end
Comprehensive Roaming Comparison
Legacy vs Modern Roaming
sequenceDiagram
participant C as Client
participant OldAP as Current AP
participant NewAP as Target AP
participant Auth as Auth Server
rect rgb(255, 230, 230)
Note over C,Auth: Legacy Roaming (~100ms)
C->>C: Scan all channels (20-50ms)
C->>NewAP: Probe Request
NewAP->>C: Probe Response
C->>NewAP: Authentication Request
NewAP->>C: Authentication Response
C->>NewAP: Association Request
NewAP->>Auth: Full 802.1X (30-50ms)
Auth->>NewAP: PMK
NewAP->>C: 4-Way Handshake (20ms)
C->>NewAP: Reassociation Complete
end
Note over C: ---
rect rgb(230, 255, 230)
Note over C,Auth: Modern Roaming with 802.11r/k/v (<20ms)
OldAP->>C: Neighbor Report (802.11k)
OldAP->>C: BTM Request (802.11v)
C->>C: Targeted scan (5-10ms)
C->>NewAP: FT Authentication (802.11r)
NewAP->>OldAP: Request PMK context
OldAP->>NewAP: Transfer PMK
NewAP->>C: FT Authentication Response
C->>NewAP: FT Reassociation (<10ms)
Note over C,NewAP: Roaming complete!
end
How the Standards Work Together
graph TB
Client[WiFi Client Device]
subgraph Discovery["Discovery Phase (802.11k)"]
K1[Request Neighbor Report] --> K2[Receive AP List with<br/>RSSI, Channel, Load]
K2 --> K3[Targeted Scanning]
end
subgraph Decision["Decision Phase (802.11v)"]
V1[Receive BTM Request] --> V2[Evaluate Candidates<br/>+ Network Suggestions]
V2 --> V3[Select Optimal AP]
end
subgraph Transition["Transition Phase (802.11r)"]
R1[FT Authentication<br/>with PMK Cache] --> R2[Fast Key Derivation]
R2 --> R3[FT Reassociation<br/><10ms]
end
subgraph Security["Security (802.11w)"]
W1[Protected Management<br/>Frames] --> W2[Prevent Deauth Attacks]
W2 --> W3[Secure Roaming Process]
end
Client --> Discovery
Discovery --> Decision
Decision --> Transition
Transition --> Security
Security --> Connected[Connected to New AP]
style Discovery fill:#E6F3FF
style Decision fill:#FFE6F0
style Transition fill:#E6FFE6
style Security fill:#FFF9E6
style Connected fill:#90EE90
Performance Comparison Table
| Roaming Aspect | Legacy | With 802.11r | With 802.11r/k/v | Full r/k/v/w |
|---|---|---|---|---|
| Total Roam Latency | 50-100ms | ~50ms | <20ms | <20ms |
| AP Discovery | Full scan (all channels) | Full scan | Targeted scan | Targeted scan |
| Decision Making | Client-only | Client-only | Network-assisted | Network-assisted |
| Authentication | Full 802.1X | PMK caching | PMK caching | PMK caching |
| Security | Vulnerable to deauth | Vulnerable to deauth | Vulnerable to deauth | Protected |
| VoIP Quality | May experience dropouts | Seamless | Seamless | Seamless |
| Battery Impact | High (full scans) | Medium | Low (targeted) | Low (targeted) |
| Best For | Simple networks | Fast handoffs | Enterprise | Enterprise + Security |
Real-World Roaming Timeline
gantt
title Roaming Process Timeline Comparison
dateFormat X
axisFormat %Lms
section Legacy
Channel Scanning :0, 50
Authentication :50, 30
4-Way Handshake :80, 20
Total (100ms) :0, 100
section 802.11r Only
Channel Scanning :0, 40
FT Authentication :40, 5
FT Reassociation :45, 5
Total (50ms) :0, 50
section 802.11r+k+v
Targeted Scanning :0, 10
FT Authentication :10, 5
FT Reassociation :15, 5
Total (20ms) :0, 20
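The timeline above can be tallied as simple per-phase budgets in milliseconds, matching the gantt figures: each scheme's total is just the sum of its phases, and targeted scanning (802.11k/v) is what removes most of the remaining cost once FT has shrunk the authentication step.

```python
# Phase budgets (ms) taken from the roaming timeline above.
phases = {
    "legacy":   {"scan": 50, "auth": 30, "handshake": 20},
    "11r_only": {"scan": 40, "ft_auth": 5, "ft_reassoc": 5},
    "11r_k_v":  {"scan": 10, "ft_auth": 5, "ft_reassoc": 5},
}
totals = {name: sum(p.values()) for name, p in phases.items()}
print(totals)  # {'legacy': 100, '11r_only': 50, '11r_k_v': 20}
```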
Implementation Considerations
Enabling Fast Roaming
To achieve optimal roaming performance, consider the following:
- Network Requirements:
- All APs must support the same roaming standards (802.11r/k/v/w)
- APs should be part of the same Mobility Domain (for 802.11r)
- Backend infrastructure must support PMK caching and distribution
- Configuration Best Practices:
- Enable FT over-the-DS for better performance in dense deployments
- Configure neighbor reports accurately with current AP information
- Set appropriate BTM preference values to guide client decisions
- Ensure PMF (802.11w) is enabled for security
- Client Support:
- Verify client devices support the required standards
- Update client drivers and firmware for best compatibility
- Test roaming behavior with target applications (VoIP, video conferencing)
- Tuning Parameters:
- RSSI thresholds for roaming triggers (typically -70 to -75 dBm)
- Channel overlap and interference considerations
- Load balancing thresholds for BTM requests
- Roaming retry intervals and timeouts
Common Deployment Scenarios
graph LR
subgraph Scenario1["Enterprise Office"]
S1[High Density APs] --> S1A[802.11r/k/v/w<br/>All Enabled]
S1A --> S1B[VoIP & Video<br/>Optimized]
end
subgraph Scenario2["Public Venue"]
S2[Medium Density APs] --> S2A[802.11r/k<br/>Minimum]
S2A --> S2B[Basic Mobility<br/>Support]
end
subgraph Scenario3["Home/Small Office"]
S3[2-3 APs] --> S3A[Legacy or<br/>802.11r Only]
S3A --> S3B[Simple Setup<br/>Acceptable]
end
style Scenario1 fill:#E6FFE6
style Scenario2 fill:#FFF9E6
style Scenario3 fill:#FFE6E6
Troubleshooting and Monitoring
Common Roaming Issues
- Sticky Client Problem:
- Symptom: Client stays connected to distant AP despite closer APs being available
- Cause: Client roaming algorithm too conservative, high RSSI disconnect threshold
- Solution: Use 802.11v BTM to encourage roaming, adjust AP minimum RSSI settings
- Ping-Pong Roaming:
- Symptom: Client rapidly switches between two APs
- Cause: APs have overlapping coverage with similar signal strength
- Solution: Adjust AP transmit power, implement roaming hysteresis, use 802.11k/v
- Failed Fast Transitions:
- Symptom: Roaming takes longer than expected or fails completely
- Cause: PMK not properly distributed, Mobility Domain misconfiguration
- Solution: Verify all APs share same MDID, check R0KH/R1KH communication
- Authentication Timeouts:
- Symptom: Client disconnects during roaming attempt
- Cause: Slow authentication server, network latency
- Solution: Enable 802.11r PMK caching, optimize RADIUS server response time
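The roaming hysteresis mentioned as a fix for ping-pong roaming can be sketched as a two-condition check. The threshold values follow the -70 to -75 dBm guidance elsewhere in this section; the function name and 6 dB margin are illustrative assumptions, since real client roaming algorithms are vendor-specific.

```python
# Sketch of roaming hysteresis: only roam when the current link has actually
# degraded AND the candidate beats it by a clear margin. Without the margin,
# two APs with similar signal cause rapid ping-pong switching.
def should_roam(current_rssi, candidate_rssi,
                trigger_dbm=-72, hysteresis_db=6):
    if current_rssi > trigger_dbm:
        return False                     # current link still healthy, stay put
    return candidate_rssi >= current_rssi + hysteresis_db

print(should_roam(-75, -73))   # False: 2 dB gain is inside the hysteresis band
print(should_roam(-75, -65))   # True: clear 10 dB improvement
print(should_roam(-60, -50))   # False: current AP still strong enough
```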
Monitoring Tools and Metrics
graph TB
subgraph Metrics["Key Roaming Metrics"]
M1[Roaming Latency<br/>Target: <50ms]
M2[Roaming Success Rate<br/>Target: >95%]
M3[Average RSSI<br/>at Roaming Decision<br/>Target: -70 to -75 dBm]
M4[Failed Authentications<br/>Target: <1%]
M5[Client Association Time<br/>Target: <100ms]
end
subgraph Tools["Monitoring Tools"]
T1[Wireless Controller Logs]
T2[RADIUS Server Logs]
T3[Client-Side Tools<br/>iw, wpa_supplicant]
T4[Packet Capture<br/>Wireshark, tcpdump]
T5[Network Management System]
end
Metrics --> Analysis[Roaming<br/>Performance Analysis]
Tools --> Analysis
Analysis --> Actions[Optimization Actions]
style Metrics fill:#E6F3FF
style Tools fill:#FFE6F0
style Analysis fill:#FFF9E6
style Actions fill:#90EE90
Debugging Commands
Linux Client (iw/wpa_supplicant):
# Check current connection and roaming capabilities
iw dev wlan0 link
iw dev wlan0 scan | grep -E "BSS|SSID|freq|signal|capability"
# Monitor roaming events in real-time
wpa_cli -i wlan0
> status
> bss_flush 0
> scan
> scan_results
# Check FT (802.11r) capabilities
iw dev wlan0 scan | grep -A 20 "your-ssid" | grep -E "FT|Mobility"
# Monitor neighbor reports (802.11k)
iw dev wlan0 station dump
Access Point (hostapd):
# Enable debug logging for roaming events
hostapd -dd /etc/hostapd/hostapd.conf | grep -E "FT|BTM|neighbor"
# Check associated clients and their roaming status
hostapd_cli all_sta
hostapd_cli status
# Send BSS transition management request
hostapd_cli bss_tm_req <client-mac> neighbor=<target-bssid>,<bssid-info>,<op-class>,<channel>,<phy-type>
Wireshark Filters for Roaming Analysis:
# 802.11r roaming events: Authentication (FT auth) and Deauthentication frames
wlan.fc.type_subtype == 0x000b || wlan.fc.type_subtype == 0x000c
# Association and Reassociation frames
wlan.fc.type_subtype == 0x0000 || wlan.fc.type_subtype == 0x0001 ||
wlan.fc.type_subtype == 0x0002 || wlan.fc.type_subtype == 0x0003
# 802.11k Neighbor Reports
wlan.tag.number == 52
# 802.11v BSS Transition Management
wlan.fixed.action_code == 7 || wlan.fixed.action_code == 8
# 802.11w Protected Management Frames
wlan.fc.protected == 1 && (wlan.fc.type == 0)
Performance Optimization Tips
- Channel Planning:
- Use non-overlapping channels (1, 6, 11 for 2.4 GHz)
- Minimize co-channel interference in high-density deployments
- Consider DFS channels in 5 GHz for additional capacity
- AP Placement and Power:
- Ensure 20-30% cell overlap for seamless roaming
- Reduce AP transmit power in dense deployments to prevent sticky clients
- Use site survey tools to validate coverage
- RSSI Thresholds:
- Set roaming trigger at -70 to -75 dBm for optimal performance
- Configure minimum RSSI for association rejection at -80 to -85 dBm
- Implement different thresholds for 2.4 GHz vs 5 GHz
- Fast Roaming Configuration:
- Enable FT over-the-DS for centralized architectures
- Configure neighbor reports with accurate channel and BSSID information
- Set appropriate BTM preference values to guide client decisions
- Ensure PMK caching timeout (default 43200 seconds) is appropriate
Sample wpa_supplicant Configuration
network={
ssid="YourNetwork"
psk="YourPassword"
key_mgmt=WPA-PSK WPA-PSK-SHA256 FT-PSK
ieee80211w=2 # Require PMF (802.11w)
# Fast roaming settings
proactive_key_caching=1 # Enable opportunistic key caching
ft_eap_pmksa_caching=1 # Enable PMK caching for FT
# Background scanning: simple:<short-interval-s>:<rssi-threshold-dbm>:<long-interval-s>
bgscan="simple:30:-70:3600" # Scan every 30s when below -70 dBm, else every 3600s
# Scan frequency configuration
scan_freq=2412 2437 2462 5180 5200 5220 5240 5745 5765 5785 5805
}
Sample hostapd Configuration
# Basic settings
interface=wlan0
driver=nl80211
ssid=YourNetwork
wpa=2
wpa_key_mgmt=WPA-PSK FT-PSK
wpa_pairwise=CCMP
# 802.11r Fast BSS Transition
mobility_domain=a1b2 # Same for all APs in the domain
ft_over_ds=1 # Enable FT over Distribution System
ft_psk_generate_local=1 # Generate PMK-R0/R1 locally
nas_identifier=ap1.example.com # Unique per AP
r0kh=02:00:00:00:03:00 ap1.example.com 000102030405060708090a0b0c0d0e0f
r1kh=02:00:00:00:03:00 00:00:00:00:03:00 000102030405060708090a0b0c0d0e0f
# 802.11k Radio Resource Management
rrm_neighbor_report=1
rrm_beacon_report=1
# 802.11v BSS Transition Management
bss_transition=1
wnm_sleep_mode=1
time_advertisement=2
# 802.11w Protected Management Frames
ieee80211w=2 # Required (2) or Optional (1)
group_mgmt_cipher=AES-128-CMAC
# Roaming optimization
ap_max_inactivity=300
disassoc_low_ack=1
skip_inactivity_poll=0
References and Further Reading
- IEEE Standards:
- IEEE 802.11r-2008: Fast BSS Transition
- IEEE 802.11k-2008: Radio Resource Measurement
- IEEE 802.11v-2011: Wireless Network Management
- IEEE 802.11w-2009: Protected Management Frames
- RFCs and Documentation:
- RFC 5416: Control and Provisioning of Wireless Access Points (CAPWAP) Protocol
- Wi-Fi Alliance: WPA3 Security Specification
- hostapd documentation: https://w1.fi/hostapd/
- wpa_supplicant documentation: https://w1.fi/wpa_supplicant/
- Best Practices:
- Cisco Enterprise Mobility Design Guide
- Aruba Best Practices for High-Density WiFi Deployments
- Ruckus SmartRoam Technology Overview
QoS Management
QoS Map
MSCS
SCS
DSCP Policy
Machine Learning
A comprehensive guide to machine learning concepts, algorithms, and implementations.
Table of Contents
- Supervised Learning - Classification, regression, and supervised algorithms
- Unsupervised Learning - Clustering and dimensionality reduction
- Reinforcement Learning - RL concepts, Q-learning, and policy gradients
- Deep Learning - Neural networks, CNNs, RNNs, and training techniques
- Neural Networks - Architecture, backpropagation, activation functions
- Deep Reinforcement Learning - DQN, A3C, PPO, and advanced RL
- Generative Models - GANs, VAEs, and flow-based models
- Deep Generative Models - Advanced generative architectures
- Transfer Learning - Pre-training, fine-tuning, and domain adaptation
- PyTorch - Deep learning framework, tensors, autograd, training
- NumPy - Foundational numerical computing for ML implementations
- Quantization - Model compression, INT8/INT4 quantization, GPTQ, AWQ
- Transformers - Attention mechanisms, BERT, GPT architectures
- Hugging Face - Transformers library, models, and datasets
- Interesting Papers - Key ML papers and summaries
Overview
Machine Learning is a field of artificial intelligence that focuses on building systems that learn from data. The field can be broadly categorized into:
Supervised Learning
Learning from labeled data where each example has an input-output pair. The goal is to learn a mapping from inputs to outputs.
- Classification: Predicting discrete categories (e.g., spam/not spam)
- Regression: Predicting continuous values (e.g., house prices)
Unsupervised Learning
Learning patterns from unlabeled data without explicit output labels.
- Clustering: Grouping similar data points together
- Dimensionality Reduction: Reducing the number of features while preserving information
- Anomaly Detection: Identifying outliers in data
Reinforcement Learning
Learning through interaction with an environment to maximize cumulative rewards.
- Model-free RL: Learning without modeling the environment
- Model-based RL: Learning a model of the environment
- Deep RL: Combining deep learning with reinforcement learning
Key Concepts
The Machine Learning Pipeline
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
# 1. Load and prepare data (a bundled dataset, so the example runs as-is)
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)  # Features and labels
# 2. Split data
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
# 3. Preprocess data
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
# 4. Train model
model = LogisticRegression()
model.fit(X_train_scaled, y_train)
# 5. Evaluate
y_pred = model.predict(X_test_scaled)
print(f"Accuracy: {accuracy_score(y_test, y_pred):.4f}")
print(classification_report(y_test, y_pred))
Bias-Variance Tradeoff
The bias-variance tradeoff is a fundamental concept in machine learning:
- Bias: Error from overly simplistic assumptions (underfitting)
- Variance: Error from sensitivity to small fluctuations in training data (overfitting)
- Total Error = Bias² + Variance + Irreducible Error
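The decomposition above can be checked numerically. This is an illustrative simulation (the target function, noise level, and polynomial degrees are choices, not from the text): fit a too-simple and a very flexible model to many noisy resamples of a known function, then estimate bias² and variance of the predictions at a fixed test point.

```python
# Numeric illustration of the bias-variance tradeoff: a degree-1 polynomial
# underfits sin(3x) (high bias), while a degree-9 polynomial tracks it closely
# but its predictions swing more across resamples (high variance).
import numpy as np

rng = np.random.default_rng(42)
true_f = lambda x: np.sin(3 * x)
x_train = np.linspace(-1, 1, 30)
x0 = 0.5                                  # evaluation point

def predictions(degree, trials=300):
    preds = []
    for _ in range(trials):
        y = true_f(x_train) + rng.normal(0, 0.5, x_train.size)
        coeffs = np.polyfit(x_train, y, degree)
        preds.append(np.polyval(coeffs, x0))
    return np.array(preds)

for degree in (1, 9):
    p = predictions(degree)
    bias2 = (p.mean() - true_f(x0)) ** 2
    var = p.var()
    print(f"degree {degree}: bias^2={bias2:.4f}, variance={var:.4f}")
# The simple model shows high bias and low variance; the flexible model
# shows low bias and higher variance.
```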
import matplotlib.pyplot as plt
from sklearn.model_selection import learning_curve
def plot_learning_curve(estimator, X, y):
train_sizes, train_scores, val_scores = learning_curve(
estimator, X, y, cv=5, n_jobs=-1,
train_sizes=np.linspace(0.1, 1.0, 10)
)
train_mean = np.mean(train_scores, axis=1)
train_std = np.std(train_scores, axis=1)
val_mean = np.mean(val_scores, axis=1)
val_std = np.std(val_scores, axis=1)
plt.figure(figsize=(10, 6))
plt.plot(train_sizes, train_mean, label='Training score')
plt.plot(train_sizes, val_mean, label='Validation score')
plt.fill_between(train_sizes, train_mean - train_std,
train_mean + train_std, alpha=0.1)
plt.fill_between(train_sizes, val_mean - val_std,
val_mean + val_std, alpha=0.1)
plt.xlabel('Training Set Size')
plt.ylabel('Score')
plt.legend()
plt.title('Learning Curve')
plt.show()
Cross-Validation
Cross-validation helps assess model performance and reduce overfitting:
from sklearn.model_selection import cross_val_score, KFold
# K-Fold Cross-Validation
kfold = KFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(model, X, y, cv=kfold, scoring='accuracy')
print(f"Cross-validation scores: {scores}")
print(f"Mean accuracy: {scores.mean():.4f} (+/- {scores.std() * 2:.4f})")
# Stratified K-Fold (maintains class distribution)
from sklearn.model_selection import StratifiedKFold
skfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(model, X, y, cv=skfold, scoring='accuracy')
Regularization
Regularization techniques help prevent overfitting:
L1 Regularization (Lasso): Encourages sparsity
Loss = MSE + λ * Σ|w_i|
L2 Regularization (Ridge): Penalizes large weights
Loss = MSE + λ * Σw_i²
Elastic Net: Combines L1 and L2
Loss = MSE + λ₁ * Σ|w_i| + λ₂ * Σw_i²
from sklearn.linear_model import Lasso, Ridge, ElasticNet
# L1 Regularization
lasso = Lasso(alpha=0.1)
lasso.fit(X_train, y_train)
# L2 Regularization
ridge = Ridge(alpha=1.0)
ridge.fit(X_train, y_train)
# Elastic Net
elastic = ElasticNet(alpha=0.1, l1_ratio=0.5)
elastic.fit(X_train, y_train)
Feature Engineering
Feature engineering is crucial for model performance:
import pandas as pd
from sklearn.preprocessing import PolynomialFeatures, OneHotEncoder
# Polynomial features
poly = PolynomialFeatures(degree=2, include_bias=False)
X_poly = poly.fit_transform(X)
# One-hot encoding for categorical variables
encoder = OneHotEncoder(sparse_output=False, drop='first')  # sparse_output replaces sparse in scikit-learn >= 1.2
X_encoded = encoder.fit_transform(df[['category1', 'category2']])
# Feature scaling
from sklearn.preprocessing import MinMaxScaler, RobustScaler
# Min-Max scaling (0 to 1)
minmax_scaler = MinMaxScaler()
X_minmax = minmax_scaler.fit_transform(X)
# Robust scaling (uses median and IQR)
robust_scaler = RobustScaler()
X_robust = robust_scaler.fit_transform(X)
Hyperparameter Tuning
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
from sklearn.ensemble import RandomForestClassifier
# Grid Search
param_grid = {
'n_estimators': [100, 200, 300],
'max_depth': [10, 20, 30, None],
'min_samples_split': [2, 5, 10],
'min_samples_leaf': [1, 2, 4]
}
grid_search = GridSearchCV(
RandomForestClassifier(),
param_grid,
cv=5,
scoring='accuracy',
n_jobs=-1
)
grid_search.fit(X_train, y_train)
print(f"Best parameters: {grid_search.best_params_}")
# Randomized Search (faster for large parameter spaces)
from scipy.stats import randint, uniform
param_distributions = {
'n_estimators': randint(100, 500),
'max_depth': randint(10, 50),
'min_samples_split': randint(2, 20),
'min_samples_leaf': randint(1, 10)
}
random_search = RandomizedSearchCV(
RandomForestClassifier(),
param_distributions,
n_iter=100,
cv=5,
random_state=42,
n_jobs=-1
)
random_search.fit(X_train, y_train)
Evaluation Metrics
Classification Metrics
from sklearn.metrics import (
accuracy_score, precision_score, recall_score, f1_score,
confusion_matrix, roc_auc_score, roc_curve
)
# Basic metrics
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred, average='weighted')
recall = recall_score(y_test, y_pred, average='weighted')
f1 = f1_score(y_test, y_pred, average='weighted')
# Confusion matrix
cm = confusion_matrix(y_test, y_pred)
print(cm)
# ROC-AUC
y_proba = model.predict_proba(X_test)[:, 1]
auc = roc_auc_score(y_test, y_proba)
fpr, tpr, thresholds = roc_curve(y_test, y_proba)
# Plot ROC curve
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, label=f'ROC curve (AUC = {auc:.2f})')
plt.plot([0, 1], [0, 1], 'k--', label='Random classifier')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend()
plt.show()
Regression Metrics
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
# Mean Squared Error
mse = mean_squared_error(y_test, y_pred)
rmse = np.sqrt(mse)
# Mean Absolute Error
mae = mean_absolute_error(y_test, y_pred)
# R² Score
r2 = r2_score(y_test, y_pred)
print(f"RMSE: {rmse:.4f}")
print(f"MAE: {mae:.4f}")
print(f"R² Score: {r2:.4f}")
Common Pitfalls and Best Practices
1. Data Leakage
Ensure test data doesn't leak into training:
# WRONG: Scaling before splitting
X_scaled = scaler.fit_transform(X)
X_train, X_test = train_test_split(X_scaled)
# CORRECT: Fit scaler only on training data
X_train, X_test = train_test_split(X)
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
2. Class Imbalance
Handle imbalanced datasets:
from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import RandomUnderSampler
# SMOTE (Synthetic Minority Over-sampling)
smote = SMOTE(random_state=42)
X_resampled, y_resampled = smote.fit_resample(X_train, y_train)
# Class weights
model = LogisticRegression(class_weight='balanced')
3. Feature Selection
Remove irrelevant features:
from sklearn.feature_selection import (
SelectKBest, f_classif, RFE, SelectFromModel
)
# Univariate feature selection
selector = SelectKBest(f_classif, k=10)
X_selected = selector.fit_transform(X_train, y_train)
# Recursive Feature Elimination
rfe = RFE(estimator=LogisticRegression(), n_features_to_select=10)
X_rfe = rfe.fit_transform(X_train, y_train)
# Model-based selection
sfm = SelectFromModel(RandomForestClassifier(), threshold='median')
X_sfm = sfm.fit_transform(X_train, y_train)
Resources
- Books:
- "Pattern Recognition and Machine Learning" by Christopher Bishop
- "The Elements of Statistical Learning" by Hastie, Tibshirani, and Friedman
- "Deep Learning" by Goodfellow, Bengio, and Courville
- Courses:
- Andrew Ng's Machine Learning Course (Coursera)
- Fast.ai Practical Deep Learning
- Stanford CS229: Machine Learning
- Libraries:
- scikit-learn: Traditional ML algorithms
- PyTorch: Deep learning framework
- TensorFlow/Keras: Deep learning framework
- XGBoost/LightGBM: Gradient boosting
- Hugging Face: Transformers and NLP
Quick Reference
Model Selection Guide
| Problem Type | Recommended Models |
|---|---|
| Linear separable data | Logistic Regression, SVM (linear) |
| Non-linear data | Random Forest, XGBoost, Neural Networks |
| High-dimensional data | Ridge/Lasso Regression, SVM |
| Small dataset | SVM, Naive Bayes, Linear Models |
| Large dataset | SGD-based models, Deep Learning |
| Interpretability needed | Linear Models, Decision Trees |
| Image data | CNNs, Vision Transformers |
| Text data | Transformers, RNNs, TF-IDF + Classical ML |
| Time series | RNNs, LSTMs, Transformers, ARIMA |
| Tabular data | XGBoost, LightGBM, Random Forest |
Performance Optimization
# Use efficient data structures
import pandas as pd
df = pd.read_csv('data.csv', dtype={'col1': 'category'})
# Parallel processing
from joblib import Parallel, delayed
results = Parallel(n_jobs=-1)(delayed(process)(x) for x in data)
# Batch processing for large datasets
def batch_process(data, batch_size=1000):
for i in range(0, len(data), batch_size):
batch = data[i:i+batch_size]
yield process_batch(batch)
# Use generators for memory efficiency
def data_generator(file_path):
for chunk in pd.read_csv(file_path, chunksize=1000):
yield chunk
Deep Learning
Deep learning uses artificial neural networks with multiple layers to learn hierarchical representations of data.
Table of Contents
- Neural Networks Fundamentals
- Activation Functions
- Loss Functions
- Optimization
- Regularization
- Convolutional Neural Networks (CNNs)
- Recurrent Neural Networks (RNNs)
- Attention Mechanisms
- Batch Normalization
- Advanced Architectures
Neural Networks Fundamentals
Perceptron
The basic building block of neural networks.
Mathematical Formulation:
y = f(Σ(w_i * x_i) + b)
Where:
- x_i: inputs
- w_i: weights
- b: bias
- f: activation function
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
# Simple perceptron
class Perceptron(nn.Module):
def __init__(self, input_dim):
super(Perceptron, self).__init__()
self.linear = nn.Linear(input_dim, 1)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
return self.sigmoid(self.linear(x))
# Example usage
input_dim = 4
model = Perceptron(input_dim)
x = torch.randn(32, input_dim)
output = model(x)
print(f"Output shape: {output.shape}")
Multi-Layer Perceptron (MLP)
class MLP(nn.Module):
def __init__(self, input_dim, hidden_dims, output_dim):
super(MLP, self).__init__()
layers = []
prev_dim = input_dim
# Hidden layers
for hidden_dim in hidden_dims:
layers.append(nn.Linear(prev_dim, hidden_dim))
layers.append(nn.ReLU())
prev_dim = hidden_dim
# Output layer
layers.append(nn.Linear(prev_dim, output_dim))
self.network = nn.Sequential(*layers)
def forward(self, x):
return self.network(x)
# Example: 3-layer MLP
model = MLP(input_dim=10, hidden_dims=[64, 32], output_dim=2)
x = torch.randn(16, 10)
output = model(x)
print(f"Output shape: {output.shape}")
Backpropagation
Forward Pass:
z^[l] = W^[l] · a^[l-1] + b^[l]
a^[l] = g^[l](z^[l])
Backward Pass (Chain Rule):
dL/dW^[l] = dL/da^[l] · da^[l]/dz^[l] · dz^[l]/dW^[l]
# Manual backpropagation example
class SimpleNN:
def __init__(self, input_size, hidden_size, output_size):
self.W1 = np.random.randn(input_size, hidden_size) * 0.01
self.b1 = np.zeros((1, hidden_size))
self.W2 = np.random.randn(hidden_size, output_size) * 0.01
self.b2 = np.zeros((1, output_size))
def sigmoid(self, z):
return 1 / (1 + np.exp(-z))
def sigmoid_derivative(self, a):
return a * (1 - a)
def forward(self, X):
self.z1 = np.dot(X, self.W1) + self.b1
self.a1 = self.sigmoid(self.z1)
self.z2 = np.dot(self.a1, self.W2) + self.b2
self.a2 = self.sigmoid(self.z2)
return self.a2
def backward(self, X, y, learning_rate=0.01):
m = X.shape[0]
# Output layer gradients
dz2 = self.a2 - y
dW2 = (1/m) * np.dot(self.a1.T, dz2)
db2 = (1/m) * np.sum(dz2, axis=0, keepdims=True)
# Hidden layer gradients
dz1 = np.dot(dz2, self.W2.T) * self.sigmoid_derivative(self.a1)
dW1 = (1/m) * np.dot(X.T, dz1)
db1 = (1/m) * np.sum(dz1, axis=0, keepdims=True)
# Update weights
self.W2 -= learning_rate * dW2
self.b2 -= learning_rate * db2
self.W1 -= learning_rate * dW1
self.b1 -= learning_rate * db1
Activation Functions
Common Activation Functions
import torch.nn.functional as F
# ReLU (Rectified Linear Unit)
def relu(x):
return torch.max(torch.zeros_like(x), x)
# Leaky ReLU
def leaky_relu(x, alpha=0.01):
return torch.where(x > 0, x, alpha * x)
# Sigmoid
def sigmoid(x):
return 1 / (1 + torch.exp(-x))
# Tanh
def tanh(x):
return torch.tanh(x)
# Softmax (for multi-class classification)
def softmax(x, dim=-1):
exp_x = torch.exp(x - torch.max(x, dim=dim, keepdim=True)[0])
return exp_x / torch.sum(exp_x, dim=dim, keepdim=True)
# GELU (Gaussian Error Linear Unit)
def gelu(x):
return 0.5 * x * (1 + torch.tanh(np.sqrt(2/np.pi) * (x + 0.044715 * x**3)))
# Swish/SiLU
def swish(x):
return x * torch.sigmoid(x)
# Visualization
x = torch.linspace(-5, 5, 100)
activations = {
'ReLU': F.relu(x),
'Leaky ReLU': F.leaky_relu(x, 0.1),
'Sigmoid': torch.sigmoid(x),
'Tanh': torch.tanh(x),
'GELU': F.gelu(x),
'Swish': x * torch.sigmoid(x)
}
import matplotlib.pyplot as plt
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
for ax, (name, y) in zip(axes.flatten(), activations.items()):
ax.plot(x.numpy(), y.numpy())
ax.set_title(name)
ax.grid(True)
ax.axhline(y=0, color='k', linestyle='--', alpha=0.3)
ax.axvline(x=0, color='k', linestyle='--', alpha=0.3)
plt.tight_layout()
plt.show()
Loss Functions
Classification Losses
# Binary Cross-Entropy
def binary_cross_entropy(predictions, targets):
return -torch.mean(targets * torch.log(predictions + 1e-8) +
(1 - targets) * torch.log(1 - predictions + 1e-8))
# Categorical Cross-Entropy
def categorical_cross_entropy(predictions, targets):
return -torch.mean(torch.sum(targets * torch.log(predictions + 1e-8), dim=1))
# Focal Loss (for imbalanced datasets)
class FocalLoss(nn.Module):
def __init__(self, alpha=1, gamma=2):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
def forward(self, inputs, targets):
ce_loss = F.cross_entropy(inputs, targets, reduction='none')
pt = torch.exp(-ce_loss)
focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
return focal_loss.mean()
# Using PyTorch built-in losses
criterion_bce = nn.BCELoss()
criterion_ce = nn.CrossEntropyLoss()
criterion_nll = nn.NLLLoss()
# Example (CrossEntropyLoss expects raw logits and integer class indices)
predictions = torch.randn(32, 10)  # raw logits, not probabilities
targets = torch.randint(0, 10, (32,))
loss = criterion_ce(predictions, targets)
Regression Losses
# Mean Squared Error (MSE)
criterion_mse = nn.MSELoss()
# Mean Absolute Error (MAE)
criterion_mae = nn.L1Loss()
# Smooth L1 Loss (Huber Loss)
criterion_smooth = nn.SmoothL1Loss()
# Custom loss example
class CustomRegressionLoss(nn.Module):
def __init__(self):
super(CustomRegressionLoss, self).__init__()
def forward(self, predictions, targets):
mse = torch.mean((predictions - targets) ** 2)
mae = torch.mean(torch.abs(predictions - targets))
return mse + 0.1 * mae
Optimization
Optimizers
# Stochastic Gradient Descent (SGD)
optimizer_sgd = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# Adam (Adaptive Moment Estimation)
optimizer_adam = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
# AdamW (Adam with weight decay)
optimizer_adamw = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
# RMSprop
optimizer_rmsprop = optim.RMSprop(model.parameters(), lr=0.001, alpha=0.99)
# Adagrad
optimizer_adagrad = optim.Adagrad(model.parameters(), lr=0.01)
# Training loop
def train_epoch(model, dataloader, criterion, optimizer, device):
model.train()
total_loss = 0
for batch_idx, (data, target) in enumerate(dataloader):
data, target = data.to(device), target.to(device)
# Forward pass
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
# Backward pass
loss.backward()
# Gradient clipping (optional)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# Update weights
optimizer.step()
total_loss += loss.item()
return total_loss / len(dataloader)
Learning Rate Scheduling
from torch.optim.lr_scheduler import (
StepLR, ExponentialLR, CosineAnnealingLR,
ReduceLROnPlateau, OneCycleLR
)
# Step decay
scheduler_step = StepLR(optimizer, step_size=10, gamma=0.1)
# Exponential decay
scheduler_exp = ExponentialLR(optimizer, gamma=0.95)
# Cosine annealing
scheduler_cosine = CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)
# Reduce on plateau
scheduler_plateau = ReduceLROnPlateau(
optimizer, mode='min', factor=0.1, patience=10, verbose=True
)
# One Cycle Policy
scheduler_onecycle = OneCycleLR(
optimizer, max_lr=0.01, epochs=100, steps_per_epoch=len(train_loader)
)
# Usage in training loop
for epoch in range(num_epochs):
train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
val_loss = validate(model, val_loader, criterion, device)
# Step the scheduler
scheduler_plateau.step(val_loss) # For ReduceLROnPlateau
# OR
scheduler_step.step() # For other schedulers
print(f"Epoch {epoch}: LR = {optimizer.param_groups[0]['lr']:.6f}")
Regularization
Dropout
class MLPWithDropout(nn.Module):
def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate=0.5):
super(MLPWithDropout, self).__init__()
layers = []
prev_dim = input_dim
for hidden_dim in hidden_dims:
layers.append(nn.Linear(prev_dim, hidden_dim))
layers.append(nn.ReLU())
layers.append(nn.Dropout(dropout_rate))
prev_dim = hidden_dim
layers.append(nn.Linear(prev_dim, output_dim))
self.network = nn.Sequential(*layers)
def forward(self, x):
return self.network(x)
# Dropout variants
dropout = nn.Dropout(p=0.5) # Standard dropout
dropout_2d = nn.Dropout2d(p=0.5) # For Conv2d
dropout_3d = nn.Dropout3d(p=0.5) # For Conv3d
alpha_dropout = nn.AlphaDropout(p=0.5) # For SELU activation
Weight Decay (L2 Regularization)
# Weight decay in optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01)
# Manual L2 regularization
def l2_regularization(model, lambda_l2=0.01):
    l2_loss = 0
    for param in model.parameters():
        l2_loss += torch.sum(param ** 2)  # L2 penalizes the sum of squared weights
    return lambda_l2 * l2_loss
# In training loop
loss = criterion(output, target) + l2_regularization(model)
Early Stopping
class EarlyStopping:
def __init__(self, patience=10, min_delta=0, mode='min'):
self.patience = patience
self.min_delta = min_delta
self.mode = mode
self.counter = 0
self.best_score = None
self.early_stop = False
def __call__(self, val_loss):
score = -val_loss if self.mode == 'min' else val_loss
if self.best_score is None:
self.best_score = score
elif score < self.best_score + self.min_delta:
self.counter += 1
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.counter = 0
return self.early_stop
# Usage
early_stopping = EarlyStopping(patience=10)
for epoch in range(num_epochs):
train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
val_loss = validate(model, val_loader, criterion, device)
if early_stopping(val_loss):
print(f"Early stopping at epoch {epoch}")
break
Convolutional Neural Networks
CNNs are specialized for processing grid-like data (images, videos).
Basic CNN Architecture
class SimpleCNN(nn.Module):
def __init__(self, num_classes=10):
super(SimpleCNN, self).__init__()
# Convolutional layers
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
# Pooling
self.pool = nn.MaxPool2d(2, 2)
# Fully connected layers
self.fc1 = nn.Linear(128 * 4 * 4, 512)
self.fc2 = nn.Linear(512, num_classes)
# Dropout
self.dropout = nn.Dropout(0.5)
def forward(self, x):
# Conv block 1
x = self.pool(F.relu(self.conv1(x))) # 32x32 -> 16x16
# Conv block 2
x = self.pool(F.relu(self.conv2(x))) # 16x16 -> 8x8
# Conv block 3
x = self.pool(F.relu(self.conv3(x))) # 8x8 -> 4x4
# Flatten
x = x.view(x.size(0), -1)
# Fully connected
x = F.relu(self.fc1(x))
x = self.dropout(x)
x = self.fc2(x)
return x
# Example usage
model = SimpleCNN(num_classes=10)
x = torch.randn(4, 3, 32, 32) # Batch of 4 RGB 32x32 images
output = model(x)
print(f"Output shape: {output.shape}") # [4, 10]
Modern CNN Architectures
ResNet (Residual Networks)
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
residual = x
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(residual)
out = F.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, num_blocks, num_classes=10):
super(ResNet, self).__init__()
self.conv1 = nn.Conv2d(3, 64, 3, 1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.layer1 = self._make_layer(64, 64, num_blocks[0], stride=1)
self.layer2 = self._make_layer(64, 128, num_blocks[1], stride=2)
self.layer3 = self._make_layer(128, 256, num_blocks[2], stride=2)
self.layer4 = self._make_layer(256, 512, num_blocks[3], stride=2)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512, num_classes)
def _make_layer(self, in_channels, out_channels, num_blocks, stride):
layers = []
layers.append(ResidualBlock(in_channels, out_channels, stride))
for _ in range(1, num_blocks):
layers.append(ResidualBlock(out_channels, out_channels))
return nn.Sequential(*layers)
def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)))
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
# ResNet-18
model = ResNet([2, 2, 2, 2])
Inception Module
class InceptionModule(nn.Module):
def __init__(self, in_channels, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, out_pool):
super(InceptionModule, self).__init__()
# 1x1 conv branch
self.branch1 = nn.Sequential(
nn.Conv2d(in_channels, out_1x1, kernel_size=1),
nn.ReLU()
)
# 3x3 conv branch
self.branch2 = nn.Sequential(
nn.Conv2d(in_channels, red_3x3, kernel_size=1),
nn.ReLU(),
nn.Conv2d(red_3x3, out_3x3, kernel_size=3, padding=1),
nn.ReLU()
)
# 5x5 conv branch
self.branch3 = nn.Sequential(
nn.Conv2d(in_channels, red_5x5, kernel_size=1),
nn.ReLU(),
nn.Conv2d(red_5x5, out_5x5, kernel_size=5, padding=2),
nn.ReLU()
)
# Max pooling branch
self.branch4 = nn.Sequential(
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
nn.Conv2d(in_channels, out_pool, kernel_size=1),
nn.ReLU()
)
def forward(self, x):
branch1 = self.branch1(x)
branch2 = self.branch2(x)
branch3 = self.branch3(x)
branch4 = self.branch4(x)
return torch.cat([branch1, branch2, branch3, branch4], dim=1)
Advanced CNN Techniques
# Depthwise Separable Convolution
class DepthwiseSeparableConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3):
super(DepthwiseSeparableConv, self).__init__()
# Depthwise convolution
self.depthwise = nn.Conv2d(
in_channels, in_channels, kernel_size,
padding=kernel_size//2, groups=in_channels
)
# Pointwise convolution
self.pointwise = nn.Conv2d(in_channels, out_channels, 1)
def forward(self, x):
x = self.depthwise(x)
x = self.pointwise(x)
return x
# Squeeze-and-Excitation Block
class SEBlock(nn.Module):
def __init__(self, channels, reduction=16):
super(SEBlock, self).__init__()
self.squeeze = nn.AdaptiveAvgPool2d(1)
self.excitation = nn.Sequential(
nn.Linear(channels, channels // reduction, bias=False),
nn.ReLU(),
nn.Linear(channels // reduction, channels, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.squeeze(x).view(b, c)
y = self.excitation(y).view(b, c, 1, 1)
return x * y.expand_as(x)
Recurrent Neural Networks
RNNs process sequential data by maintaining hidden state.
Basic RNN
class SimpleRNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size, num_layers=1):
super(SimpleRNN, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
# x shape: (batch, seq_len, input_size)
h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
# RNN output
out, hn = self.rnn(x, h0)
# out shape: (batch, seq_len, hidden_size)
# Use last time step
out = self.fc(out[:, -1, :])
return out
# Example
model = SimpleRNN(input_size=10, hidden_size=64, output_size=2)
x = torch.randn(32, 20, 10) # (batch, seq_len, features)
output = model(x)
print(f"Output shape: {output.shape}") # [32, 2]
LSTM (Long Short-Term Memory)
class LSTMModel(nn.Module):
def __init__(self, input_size, hidden_size, output_size, num_layers=2, dropout=0.2):
super(LSTMModel, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.lstm = nn.LSTM(
input_size, hidden_size, num_layers,
batch_first=True, dropout=dropout if num_layers > 1 else 0
)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
# Initialize hidden and cell states
h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
# LSTM forward
out, (hn, cn) = self.lstm(x, (h0, c0))
# Use last time step
out = self.fc(out[:, -1, :])
return out
# Bidirectional LSTM
class BiLSTM(nn.Module):
def __init__(self, input_size, hidden_size, output_size, num_layers=2):
super(BiLSTM, self).__init__()
self.lstm = nn.LSTM(
input_size, hidden_size, num_layers,
batch_first=True, bidirectional=True
)
# Multiply by 2 for bidirectional
self.fc = nn.Linear(hidden_size * 2, output_size)
def forward(self, x):
out, _ = self.lstm(x)
out = self.fc(out[:, -1, :])
return out
GRU (Gated Recurrent Unit)
class GRUModel(nn.Module):
def __init__(self, input_size, hidden_size, output_size, num_layers=2, dropout=0.2):
super(GRUModel, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.gru = nn.GRU(
input_size, hidden_size, num_layers,
batch_first=True, dropout=dropout if num_layers > 1 else 0
)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
out, hn = self.gru(x, h0)
out = self.fc(out[:, -1, :])
return out
# Comparison
models = {
'RNN': SimpleRNN(10, 64, 2),
'LSTM': LSTMModel(10, 64, 2),
'GRU': GRUModel(10, 64, 2)
}
for name, model in models.items():
params = sum(p.numel() for p in model.parameters())
print(f"{name} parameters: {params}")
Attention Mechanisms
Attention allows models to focus on relevant parts of the input.
Self-Attention
class SelfAttention(nn.Module):
def __init__(self, embed_dim, num_heads=8):
super(SelfAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
self.query = nn.Linear(embed_dim, embed_dim)
self.key = nn.Linear(embed_dim, embed_dim)
self.value = nn.Linear(embed_dim, embed_dim)
self.out = nn.Linear(embed_dim, embed_dim)
def forward(self, x, mask=None):
batch_size = x.size(0)
# Linear projections
Q = self.query(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
K = self.key(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
V = self.value(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
# Scaled dot-product attention
scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.head_dim)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attention_weights = F.softmax(scores, dim=-1)
attention_output = torch.matmul(attention_weights, V)
# Concatenate heads
attention_output = attention_output.transpose(1, 2).contiguous()
attention_output = attention_output.view(batch_size, -1, self.embed_dim)
# Final linear layer
output = self.out(attention_output)
return output, attention_weights
# Example
attention = SelfAttention(embed_dim=512, num_heads=8)
x = torch.randn(32, 10, 512) # (batch, seq_len, embed_dim)
output, weights = attention(x)
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {weights.shape}")
Transformer Block
class TransformerBlock(nn.Module):
def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
super(TransformerBlock, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)  # batch_first matches (batch, seq, embed) inputs
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.feed_forward = nn.Sequential(
nn.Linear(embed_dim, ff_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(ff_dim, embed_dim)
)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# Multi-head attention
attn_output, _ = self.attention(x, x, x, attn_mask=mask)
x = self.norm1(x + self.dropout(attn_output))
# Feed-forward
ff_output = self.feed_forward(x)
x = self.norm2(x + self.dropout(ff_output))
return x
Batch Normalization
Normalizes layer inputs to improve training.
class ConvBNReLU(nn.Module):
def __init__(self, in_channels, out_channels):
super(ConvBNReLU, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU()
def forward(self, x):
return self.relu(self.bn(self.conv(x)))
# Other normalization techniques
# Layer Normalization (better for RNNs/Transformers)
layer_norm = nn.LayerNorm(normalized_shape=[128])
# Group Normalization
group_norm = nn.GroupNorm(num_groups=8, num_channels=64)
# Instance Normalization (used in style transfer)
instance_norm = nn.InstanceNorm2d(num_features=64)
Advanced Architectures
Vision Transformer (ViT)
class PatchEmbedding(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
super(PatchEmbedding, self).__init__()
self.num_patches = (img_size // patch_size) ** 2
self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
x = self.proj(x) # (B, embed_dim, H/P, W/P)
x = x.flatten(2) # (B, embed_dim, num_patches)
x = x.transpose(1, 2) # (B, num_patches, embed_dim)
return x
class VisionTransformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, num_classes=1000,
embed_dim=768, depth=12, num_heads=12):
super(VisionTransformer, self).__init__()
self.patch_embed = PatchEmbedding(img_size, patch_size, 3, embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
self.blocks = nn.ModuleList([
TransformerBlock(embed_dim, num_heads, embed_dim * 4)
for _ in range(depth)
])
self.norm = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, num_classes)
def forward(self, x):
B = x.shape[0]
x = self.patch_embed(x)
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat([cls_tokens, x], dim=1)
x = x + self.pos_embed
for block in self.blocks:
x = block(x)
x = self.norm(x)
x = x[:, 0] # Use cls token
x = self.head(x)
return x
Practical Tips
- Initialize Weights Properly: Use Xavier/He initialization
- Monitor Gradients: Check for vanishing/exploding gradients
- Use Mixed Precision Training: Faster training with similar accuracy
- Data Augmentation: Improves generalization
- Gradient Accumulation: Train with larger effective batch sizes
- Model Checkpointing: Save best models during training
# Mixed precision training
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
for epoch in range(num_epochs):
for data, target in train_loader:
optimizer.zero_grad()
with autocast():
output = model(data)
loss = criterion(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
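Tips 5 and 6 can be sketched in a few lines as well. A hedged example (the `accumulation_steps` value, the mock `val_loss`, and the checkpoint path `best_model.pt` are illustrative):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Gradient accumulation: simulate a larger batch by accumulating
# gradients over several small batches before each optimizer step
accumulation_steps = 4
batches = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(8)]
optimizer.zero_grad()
for i, (data, target) in enumerate(batches):
    loss = criterion(model(data), target) / accumulation_steps  # scale the loss
    loss.backward()  # gradients accumulate in param.grad
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

# Model checkpointing: save the best model seen so far by validation loss
best_val_loss = float('inf')
val_loss = 0.42  # placeholder value standing in for a real validation run
if val_loss < best_val_loss:
    best_val_loss = val_loss
    torch.save({'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'val_loss': val_loss}, 'best_model.pt')
```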
Resources
- "Deep Learning" by Goodfellow, Bengio, and Courville
- PyTorch Documentation: https://pytorch.org/docs/
- TensorFlow Documentation: https://www.tensorflow.org/
- Papers with Code: https://paperswithcode.com/
Neural Networks
Overview
A neural network is a machine learning model inspired by biological brains. It consists of interconnected nodes (neurons) organized in layers that learn patterns from data.
Basic Architecture
Input Layer      Hidden Layers      Output Layer
     o               o    o
     o               o    o               o
     o               o    o               o
     o               o    o
 [n inputs]      [hidden units]     [output units]
Key Components
Neurons
Each neuron applies transformation: $\text{output} = \text{activation}(\text{weights} \cdot \text{inputs} + \text{bias})$
Activation Functions
| Function | Formula | Range | Use Case |
|---|---|---|---|
| ReLU | $\max(0, x)$ | $[0, \infty)$ | Hidden layers |
| Sigmoid | $\frac{1}{1+e^{-x}}$ | $(0, 1)$ | Binary classification |
| Tanh | $\frac{e^x - e^{-x}}{e^x + e^{-x}}$ | $(-1, 1)$ | Hidden layers |
| Softmax | $\frac{e^{x_i}}{\sum_j e^{x_j}}$ | $(0, 1)$ probabilities | Multi-class output |
| Linear | $x$ | $(-\infty, \infty)$ | Regression output |
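The neuron formula and the activations in the table can be computed directly. A minimal sketch (the weights, inputs, and bias below are arbitrary example values):

```python
import numpy as np

def neuron(x, w, b, activation):
    """output = activation(weights . inputs + bias)"""
    return activation(np.dot(w, x) + b)

relu = lambda z: np.maximum(0, z)
sigmoid = lambda z: 1 / (1 + np.exp(-z))

x = np.array([1.0, 2.0, 3.0])    # inputs
w = np.array([0.5, -0.25, 0.1])  # weights
b = 0.1                          # bias

# Pre-activation: 0.5*1 - 0.25*2 + 0.1*3 + 0.1 = 0.4
print(neuron(x, w, b, relu))     # ReLU passes the positive value through: 0.4
print(neuron(x, w, b, sigmoid))  # Sigmoid squashes it into (0, 1)
```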
Layers
- Input Layer: Raw data (28x28 pixels, word embeddings, etc.)
- Hidden Layers: Learn complex patterns through non-linear transformations
- Output Layer: Final predictions
Training Process
Forward Pass
Input flows through network:
x → W₁·x + b₁ → activation → ... → output
Loss Function
Measures prediction error:
- MSE (regression): Mean squared error
- Cross-Entropy (classification): Measures probability difference
Backpropagation
Calculates gradients and updates weights:
1. Compute loss
2. Calculate gradients: ∂(loss)/∂(weights)
3. Update weights: w = w - learning_rate × gradient
4. Repeat
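The four steps above can be sketched with autograd on a one-parameter model (the input, target, and learning rate are arbitrary example values):

```python
import torch

w = torch.tensor([2.0], requires_grad=True)  # single weight
x, target = torch.tensor([3.0]), torch.tensor([5.0])
learning_rate = 0.1

# 1. Compute loss (squared error)
loss = (w * x - target) ** 2
# 2. Calculate gradients via backpropagation
loss.backward()
# d/dw (w*x - t)^2 = 2*(w*x - t)*x = 2 * (6 - 5) * 3 = 6
print(w.grad)  # tensor([6.])
# 3. Update weights: w = w - learning_rate * gradient
with torch.no_grad():
    w -= learning_rate * w.grad  # 2.0 - 0.1 * 6.0 = 1.4
    w.grad.zero_()  # 4. Repeat: clear gradients before the next pass
```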
Optimizers
| Optimizer | Learning | Best For |
|---|---|---|
| SGD | Fixed or decaying | Simple tasks |
| Momentum | Accelerated | Faster convergence |
| Adam | Adaptive | Most modern tasks |
| RMSprop | Adaptive | Deep networks |
Code Example (PyTorch)
import torch
import torch.nn as nn
from torch.optim import Adam
# Define network
class NeuralNetwork(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super().__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_size, output_size)
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
# Create network
model = NeuralNetwork(input_size=784, hidden_size=128, output_size=10)
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.001)
# Training loop
for epoch in range(10):
for batch_x, batch_y in train_loader:
# Forward pass
outputs = model(batch_x)
loss = criterion(outputs, batch_y)
# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
Network Types
Feedforward Neural Networks (FNN)
- Data flows one direction only
- Simplest type, works for structured data
Convolutional Neural Networks (CNN)
- Specialized for image processing
- Uses filters to extract spatial features
- Reduces parameters through weight sharing
Recurrent Neural Networks (RNN)
- Processes sequences (text, time series)
- Maintains hidden state between inputs
- Variants: LSTM, GRU (better long-term memory)
Transformers
- Attention-based architecture
- Parallel processing of sequences
- Powers modern LLMs (GPT, BERT)
Hyperparameters
| Parameter | Impact | Typical Values |
|---|---|---|
| Learning Rate | Convergence speed, stability | 0.001 - 0.1 |
| Batch Size | Memory, stability | 32 - 256 |
| Hidden Units | Capacity | 64 - 2048 |
| Epochs | Training duration | 10 - 100 |
| Dropout | Regularization | 0.3 - 0.5 |
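A simple way to explore these settings is a small grid search over the typical ranges. A sketch (the grid values and the mock `evaluate` function are illustrative; in practice it would run a real train/validate cycle):

```python
import itertools

# Candidate values drawn from the typical ranges in the table above
grid = {
    'learning_rate': [0.001, 0.01, 0.1],
    'batch_size': [32, 64, 128],
}

def evaluate(lr, batch_size):
    # Stand-in for a real train/validate run; returns a mock validation score
    return 1.0 - abs(lr - 0.01) - batch_size / 10000

best_score, best_config = float('-inf'), None
for lr, bs in itertools.product(grid['learning_rate'], grid['batch_size']):
    score = evaluate(lr, bs)
    if score > best_score:
        best_score, best_config = score, {'learning_rate': lr, 'batch_size': bs}

print(best_config)  # the configuration with the highest validation score
```

For larger grids, random search or Bayesian optimization usually finds good settings with far fewer trials.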
Training Tips
1. Data Preprocessing
# Normalize inputs using training-set statistics
mean = X_train.mean()
std = X_train.std()
X_train = (X_train - mean) / std
X_test = (X_test - mean) / std  # apply the same statistics to test data
2. Early Stopping
# Stop if validation loss doesn't improve for `patience` epochs
if val_loss < best_loss:
    best_loss = val_loss
    patience = max_patience  # reset the counter on improvement
else:
    patience -= 1
    if patience == 0:
        break
3. Learning Rate Scheduling
# Decrease learning rate over time
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
for epoch in range(100):
# train...
scheduler.step()
4. Regularization
- L1/L2: Penalize large weights
- Dropout: Randomly disable neurons
- Batch Normalization: Normalize activations
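The three techniques combine naturally in a single PyTorch model. A minimal sketch (the layer sizes are illustrative; L2 is applied through the optimizer's `weight_decay`):

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(784, 128),
    nn.BatchNorm1d(128),   # Batch Normalization: normalize activations
    nn.ReLU(),
    nn.Dropout(0.5),       # Dropout: randomly disable neurons during training
    nn.Linear(128, 10),
)
# L2 regularization via weight decay in the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01)

model.eval()  # dropout and batch norm behave differently at inference time
with torch.no_grad():
    out = model(torch.randn(4, 784))
print(out.shape)  # torch.Size([4, 10])
```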
Common Issues
| Problem | Cause | Solution |
|---|---|---|
| Underfitting | Model too simple | Increase hidden units, epochs |
| Overfitting | Model too complex | Add dropout, L2 regularization |
| Vanishing Gradients | Gradients $\to$ 0 | Use ReLU, batch norm |
| Exploding Gradients | Gradients $\to \infty$ | Gradient clipping |
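Vanishing or exploding gradients can be spotted by inspecting per-layer gradient norms after a backward pass. A sketch (the tiny sigmoid model here is only for illustration):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 10), nn.Sigmoid(), nn.Linear(10, 1))
loss = model(torch.randn(16, 10)).pow(2).mean()
loss.backward()

for name, param in model.named_parameters():
    grad_norm = param.grad.norm().item()
    # Norms near 0 across many layers suggest vanishing gradients;
    # very large norms suggest exploding gradients
    print(f"{name}: grad norm = {grad_norm:.6f}")

# For exploding gradients, clip before the optimizer step
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```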
Modern Architectures
ResNet (Residual Networks)
Skip connections prevent vanishing gradients in deep networks
Attention Mechanisms
Query-Key-Value mechanism enables transformers
Vision Transformers (ViT)
Apply transformer architecture to image patches
ELI10
Think of a neural network like learning to draw:
- Input Layer: You see a cat
- Hidden Layers: Brain recognizes ears -> whiskers -> tail (learns patterns)
- Output Layer: Brain says "This is a cat!"
The network learns by:
- Making predictions (forward pass)
- Checking if wrong (loss)
- Adjusting "how to recognize cats" (backprop)
- Repeating until accurate
More hidden layers = learns more complex patterns!
Further Resources
- Neural Networks Visualization
- 3Blue1Brown Neural Networks Series
- PyTorch Tutorials
- Deep Learning Book
Supervised Learning
Supervised learning is a type of machine learning where the model learns from labeled training data to make predictions on unseen data.
Table of Contents
- Classification
- Regression
- Linear Models
- Tree-Based Models
- Support Vector Machines
- Ensemble Methods
- Naive Bayes
- K-Nearest Neighbors
Classification
Classification predicts discrete class labels. The goal is to learn a decision boundary that separates different classes.
Binary Classification
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
# Generate synthetic data
X, y = make_classification(n_samples=1000, n_features=20, n_classes=2, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Train logistic regression
model = LogisticRegression(max_iter=1000)
model.fit(X_train, y_train)
# Predictions
y_pred = model.predict(X_test)
y_proba = model.predict_proba(X_test)
# Evaluation
print(f"Accuracy: {accuracy_score(y_test, y_pred):.4f}")
print("\nConfusion Matrix:")
print(confusion_matrix(y_test, y_pred))
print("\nClassification Report:")
print(classification_report(y_test, y_pred))
Multi-class Classification
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
# Generate multi-class data
X, y = make_classification(
n_samples=1000,
n_features=20,
n_classes=5,
n_informative=15,
random_state=42
)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
# Random Forest for multi-class
rf = RandomForestClassifier(n_estimators=100, random_state=42)
rf.fit(X_train, y_train)
# One-vs-Rest (OvR) strategy
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC
ovr = OneVsRestClassifier(SVC(kernel='rbf'))
ovr.fit(X_train, y_train)
# One-vs-One (OvO) strategy
from sklearn.multiclass import OneVsOneClassifier
ovo = OneVsOneClassifier(SVC(kernel='rbf'))
ovo.fit(X_train, y_train)
Imbalanced Classification
from imblearn.over_sampling import SMOTE, ADASYN
from imblearn.under_sampling import RandomUnderSampler, TomekLinks
from imblearn.combine import SMOTETomek
from sklearn.utils.class_weight import compute_class_weight
# SMOTE - Synthetic Minority Over-sampling Technique
smote = SMOTE(random_state=42)
X_resampled, y_resampled = smote.fit_resample(X_train, y_train)
# ADASYN - Adaptive Synthetic Sampling
adasyn = ADASYN(random_state=42)
X_resampled, y_resampled = adasyn.fit_resample(X_train, y_train)
# Combined approach
smote_tomek = SMOTETomek(random_state=42)
X_resampled, y_resampled = smote_tomek.fit_resample(X_train, y_train)
# Class weights
class_weights = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
model = LogisticRegression(class_weight='balanced')
model.fit(X_train, y_train)
# Custom threshold
y_proba = model.predict_proba(X_test)[:, 1]
threshold = 0.3 # Lower threshold for minority class
y_pred_custom = (y_proba >= threshold).astype(int)
Regression
Regression predicts continuous values. The goal is to learn a function that maps inputs to outputs.
Linear Regression
Mathematical formulation:
y = β₀ + β₁x₁ + β₂x₂ + ... + βₙxₙ + ε
Where:
- y is the target variable
- x₁, x₂, ..., xₙ are features
- β₀, β₁, ..., βₙ are coefficients
- ε is the error term
from sklearn.linear_model import LinearRegression
from sklearn.datasets import make_regression
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
# Generate regression data
X, y = make_regression(n_samples=1000, n_features=10, noise=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
# Train linear regression
lr = LinearRegression()
lr.fit(X_train, y_train)
# Predictions
y_pred = lr.predict(X_test)
# Evaluation
mse = mean_squared_error(y_test, y_pred)
rmse = np.sqrt(mse)
mae = mean_absolute_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f"RMSE: {rmse:.4f}")
print(f"MAE: {mae:.4f}")
print(f"R² Score: {r2:.4f}")
# Coefficients
print("\nCoefficients:", lr.coef_)
print("Intercept:", lr.intercept_)
Polynomial Regression
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import Pipeline
# Create polynomial features
poly_features = PolynomialFeatures(degree=3, include_bias=False)
X_poly = poly_features.fit_transform(X_train)
# Using Pipeline
poly_model = Pipeline([
('poly_features', PolynomialFeatures(degree=3)),
('linear_regression', LinearRegression())
])
poly_model.fit(X_train, y_train)
y_pred_poly = poly_model.predict(X_test)
# Compare with linear
from sklearn.metrics import mean_squared_error
print(f"Linear RMSE: {np.sqrt(mean_squared_error(y_test, y_pred)):.4f}")
print(f"Polynomial RMSE: {np.sqrt(mean_squared_error(y_test, y_pred_poly)):.4f}")
Regularized Regression
Ridge Regression (L2):
Loss = Σ(y - ŷ)² + λΣβ²
Lasso Regression (L1):
Loss = Σ(y - ŷ)² + λΣ|β|
Elastic Net:
Loss = Σ(y - ŷ)² + λ₁Σ|β| + λ₂Σβ²
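Ridge also has a closed-form solution, β = (XᵀX + λI)⁻¹Xᵀy. A minimal NumPy sketch on synthetic data showing the key effect: a larger λ shrinks the coefficient norm toward zero:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))
y = X @ np.array([1.0, 0.0, -2.0, 0.0, 3.0]) + rng.normal(scale=0.5, size=100)

def ridge_closed_form(X, y, lam):
    # beta = (X^T X + lambda * I)^(-1) X^T y
    n_features = X.shape[1]
    return np.linalg.solve(X.T @ X + lam * np.eye(n_features), X.T @ y)

beta_ols = ridge_closed_form(X, y, 0.0)    # lambda = 0 reduces to OLS
beta_ridge = ridge_closed_form(X, y, 10.0)
# The L2 penalty shrinks the coefficient vector
print(np.linalg.norm(beta_ols), np.linalg.norm(beta_ridge))
```

Lasso has no closed form (the |β| term is non-differentiable); scikit-learn fits it by coordinate descent.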
from sklearn.linear_model import Ridge, Lasso, ElasticNet, LassoCV, RidgeCV
# Ridge Regression
ridge = Ridge(alpha=1.0)
ridge.fit(X_train, y_train)
y_pred_ridge = ridge.predict(X_test)
# Lasso Regression (feature selection)
lasso = Lasso(alpha=0.1)
lasso.fit(X_train, y_train)
y_pred_lasso = lasso.predict(X_test)
# Check which features were selected by Lasso
feature_importance = np.abs(lasso.coef_)
selected_features = np.where(feature_importance > 0)[0]
print(f"Selected features: {selected_features}")
# Elastic Net
elastic = ElasticNet(alpha=0.1, l1_ratio=0.5)
elastic.fit(X_train, y_train)
# Cross-validated alpha selection
ridge_cv = RidgeCV(alphas=[0.1, 1.0, 10.0, 100.0])
ridge_cv.fit(X_train, y_train)
print(f"Best Ridge alpha: {ridge_cv.alpha_}")
lasso_cv = LassoCV(alphas=[0.01, 0.1, 1.0, 10.0], cv=5)
lasso_cv.fit(X_train, y_train)
print(f"Best Lasso alpha: {lasso_cv.alpha_}")
Linear Models
Logistic Regression
Binary classification using the sigmoid function:
P(y=1|x) = 1 / (1 + e^(-z))
where z = β₀ + β₁x₁ + ... + βₙxₙ
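The sigmoid squashes the linear score z into a probability. A tiny sketch with illustrative (hypothetical) coefficients:

```python
import numpy as np

def sigmoid(z):
    # Maps any real z to a probability in (0, 1); sigmoid(0) = 0.5
    return 1.0 / (1.0 + np.exp(-z))

beta0 = 0.1                      # intercept (illustrative)
beta = np.array([0.5, -1.2])     # beta_1, beta_2 (illustrative)
x = np.array([2.0, 1.0])

z = beta0 + beta @ x             # linear score
p = sigmoid(z)                   # P(y=1 | x)
print(p)
```

Predicting class 1 when p ≥ 0.5 is equivalent to thresholding z at 0, which is why the decision boundary is linear.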
from sklearn.linear_model import LogisticRegression
# Binary classification
log_reg = LogisticRegression(
penalty='l2',
C=1.0, # Inverse of regularization strength
solver='lbfgs',
max_iter=1000
)
log_reg.fit(X_train, y_train)
# Get probabilities
probabilities = log_reg.predict_proba(X_test)
print("Class probabilities shape:", probabilities.shape)
# Decision boundary
decision_scores = log_reg.decision_function(X_test)
print("Decision scores shape:", decision_scores.shape)
# Multi-class logistic regression
multi_log_reg = LogisticRegression(multi_class='multinomial', solver='lbfgs')
multi_log_reg.fit(X_train, y_train)
Perceptron
from sklearn.linear_model import Perceptron
# Simple perceptron
perceptron = Perceptron(max_iter=1000, tol=1e-3, random_state=42)
perceptron.fit(X_train, y_train)
y_pred = perceptron.predict(X_test)
# Custom perceptron implementation
class CustomPerceptron:
def __init__(self, learning_rate=0.01, n_iterations=1000):
self.lr = learning_rate
self.n_iterations = n_iterations
self.weights = None
self.bias = None
def fit(self, X, y):
n_samples, n_features = X.shape
self.weights = np.zeros(n_features)
self.bias = 0
# Convert labels to -1 and 1
y_ = np.where(y <= 0, -1, 1)
for _ in range(self.n_iterations):
for idx, x_i in enumerate(X):
linear_output = np.dot(x_i, self.weights) + self.bias
y_predicted = np.sign(linear_output)
# Update weights if misclassified
update = self.lr * (y_[idx] - y_predicted)
self.weights += update * x_i
self.bias += update
def predict(self, X):
linear_output = np.dot(X, self.weights) + self.bias
return np.sign(linear_output)
# Train custom perceptron
custom_perc = CustomPerceptron()
custom_perc.fit(X_train, y_train)
Tree-Based Models
Decision Trees
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.tree import export_graphviz
import graphviz
# Classification tree
dt_clf = DecisionTreeClassifier(
max_depth=5,
min_samples_split=10,
min_samples_leaf=5,
criterion='gini' # or 'entropy'
)
dt_clf.fit(X_train, y_train)
# Regression tree
dt_reg = DecisionTreeRegressor(
max_depth=5,
min_samples_split=10,
min_samples_leaf=5
)
dt_reg.fit(X_train, y_train)
# Feature importance
importances = dt_clf.feature_importances_
indices = np.argsort(importances)[::-1]
print("Feature ranking:")
for i in range(min(10, len(indices))):
print(f"{i+1}. Feature {indices[i]} ({importances[indices[i]]:.4f})")
# Visualize tree
dot_data = export_graphviz(
dt_clf,
out_file=None,
feature_names=[f'feature_{i}' for i in range(X_train.shape[1])],
class_names=['class_0', 'class_1'],
filled=True,
rounded=True
)
# graph = graphviz.Source(dot_data)
Random Forest
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
# Random Forest Classifier
rf_clf = RandomForestClassifier(
n_estimators=100,
max_depth=10,
min_samples_split=10,
min_samples_leaf=4,
max_features='sqrt', # or 'log2', None
bootstrap=True,
oob_score=True, # Out-of-bag score
n_jobs=-1,
random_state=42
)
rf_clf.fit(X_train, y_train)
# Out-of-bag score
print(f"OOB Score: {rf_clf.oob_score_:.4f}")
# Random Forest Regressor
rf_reg = RandomForestRegressor(
n_estimators=100,
max_depth=10,
n_jobs=-1,
random_state=42
)
rf_reg.fit(X_train, y_train)
# Feature importance
feature_importance = pd.DataFrame({
'feature': [f'feature_{i}' for i in range(X_train.shape[1])],
'importance': rf_clf.feature_importances_
}).sort_values('importance', ascending=False)
print(feature_importance.head(10))
Gradient Boosting
from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor
# Gradient Boosting Classifier
gb_clf = GradientBoostingClassifier(
n_estimators=100,
learning_rate=0.1,
max_depth=3,
subsample=0.8,
random_state=42
)
gb_clf.fit(X_train, y_train)
# Gradient Boosting Regressor
gb_reg = GradientBoostingRegressor(
n_estimators=100,
learning_rate=0.1,
max_depth=3,
random_state=42
)
gb_reg.fit(X_train, y_train)
# Feature importance
print("Feature importances:", gb_clf.feature_importances_)
XGBoost
import xgboost as xgb
from xgboost import XGBClassifier, XGBRegressor
# XGBoost Classifier
xgb_clf = XGBClassifier(
n_estimators=100,
max_depth=6,
learning_rate=0.1,
subsample=0.8,
colsample_bytree=0.8,
objective='binary:logistic',
random_state=42
)
xgb_clf.fit(X_train, y_train)
# XGBoost Regressor
xgb_reg = XGBRegressor(
n_estimators=100,
max_depth=6,
learning_rate=0.1,
subsample=0.8,
colsample_bytree=0.8,
random_state=42
)
xgb_reg.fit(X_train, y_train)
# Using DMatrix for better performance
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)
params = {
'max_depth': 6,
'eta': 0.1,
'objective': 'binary:logistic',
'eval_metric': 'auc'
}
# Train with early stopping
evals = [(dtrain, 'train'), (dtest, 'test')]
bst = xgb.train(
params,
dtrain,
num_boost_round=1000,
evals=evals,
early_stopping_rounds=50,
verbose_eval=False
)
# Predictions
y_pred_proba = bst.predict(dtest)
LightGBM
import lightgbm as lgb
# LightGBM Classifier
lgb_clf = lgb.LGBMClassifier(
n_estimators=100,
max_depth=6,
learning_rate=0.1,
num_leaves=31,
subsample=0.8,
colsample_bytree=0.8,
random_state=42
)
lgb_clf.fit(X_train, y_train)
# Using Dataset for better performance
train_data = lgb.Dataset(X_train, label=y_train)
test_data = lgb.Dataset(X_test, label=y_test, reference=train_data)
params = {
'objective': 'binary',
'metric': 'auc',
'num_leaves': 31,
'learning_rate': 0.1,
'feature_fraction': 0.8,
'bagging_fraction': 0.8,
'bagging_freq': 5
}
# Train
gbm = lgb.train(
params,
train_data,
num_boost_round=1000,
valid_sets=[test_data],
callbacks=[lgb.early_stopping(stopping_rounds=50)]
)
Support Vector Machines
SVM finds the hyperplane that maximizes the margin between classes.
Mathematical Formulation:
Minimize: (1/2)||w||² + C·Σξᵢ
Subject to: yᵢ(w·xᵢ + b) ≥ 1 - ξᵢ, ξᵢ ≥ 0
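Eliminating the slack variables ξᵢ, the same objective can be written with the hinge loss max(0, 1 - yᵢ(w·xᵢ + b)). A minimal NumPy sketch evaluating it on a toy separable dataset (weights chosen by hand for illustration):

```python
import numpy as np

def svm_objective(w, b, X, y, C):
    # (1/2)||w||^2 + C * sum_i max(0, 1 - y_i (w . x_i + b))
    margins = y * (X @ w + b)
    hinge = np.maximum(0.0, 1.0 - margins)
    return 0.5 * np.dot(w, w) + C * hinge.sum()

# Tiny linearly separable example; labels must be in {-1, +1}
X = np.array([[2.0, 2.0], [3.0, 3.0], [-2.0, -2.0], [-3.0, -3.0]])
y = np.array([1, 1, -1, -1])
w = np.array([0.5, 0.5])  # hypothetical separating weights

loss = svm_objective(w, b=0.0, X=X, y=y, C=1.0)
print(loss)  # all margins >= 1, so only the regularization term remains
```

Since every point satisfies the margin constraint here, the hinge term is zero and the objective equals (1/2)||w||².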
Linear SVM
from sklearn.svm import SVC, LinearSVC
# Linear SVM
linear_svm = LinearSVC(C=1.0, max_iter=10000)
linear_svm.fit(X_train, y_train)
# SVC with linear kernel
svc_linear = SVC(kernel='linear', C=1.0)
svc_linear.fit(X_train, y_train)
Non-linear SVM with Kernels
# RBF (Radial Basis Function) kernel
svc_rbf = SVC(kernel='rbf', C=1.0, gamma='scale')
svc_rbf.fit(X_train, y_train)
# Polynomial kernel
svc_poly = SVC(kernel='poly', degree=3, C=1.0)
svc_poly.fit(X_train, y_train)
# Sigmoid kernel
svc_sigmoid = SVC(kernel='sigmoid', C=1.0)
svc_sigmoid.fit(X_train, y_train)
# Custom kernel
def custom_kernel(X, Y):
return np.dot(X, Y.T)
svc_custom = SVC(kernel=custom_kernel)
svc_custom.fit(X_train, y_train)
SVM for Regression
from sklearn.svm import SVR
# Support Vector Regression
svr = SVR(kernel='rbf', C=1.0, epsilon=0.1)
svr.fit(X_train, y_train)
y_pred_svr = svr.predict(X_test)
# Linear SVR
from sklearn.svm import LinearSVR
linear_svr = LinearSVR(epsilon=0.1, C=1.0)
linear_svr.fit(X_train, y_train)
Ensemble Methods
Bagging
from sklearn.ensemble import BaggingClassifier, BaggingRegressor
from sklearn.tree import DecisionTreeClassifier
# Bagging with decision trees
bagging_clf = BaggingClassifier(
estimator=DecisionTreeClassifier(),  # 'base_estimator' before scikit-learn 1.2
n_estimators=100,
max_samples=0.8,
max_features=0.8,
bootstrap=True,
oob_score=True,
n_jobs=-1,
random_state=42
)
bagging_clf.fit(X_train, y_train)
print(f"OOB Score: {bagging_clf.oob_score_:.4f}")
Boosting
AdaBoost:
from sklearn.ensemble import AdaBoostClassifier, AdaBoostRegressor
# AdaBoost Classifier
ada_clf = AdaBoostClassifier(
estimator=DecisionTreeClassifier(max_depth=1),  # decision stumps; 'base_estimator' before scikit-learn 1.2
n_estimators=100,
learning_rate=1.0,
random_state=42
)
ada_clf.fit(X_train, y_train)
# AdaBoost Regressor
ada_reg = AdaBoostRegressor(
estimator=DecisionTreeRegressor(max_depth=3),  # 'base_estimator' before scikit-learn 1.2
n_estimators=100,
learning_rate=1.0,
random_state=42
)
ada_reg.fit(X_train, y_train)
Stacking
from sklearn.ensemble import StackingClassifier, StackingRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC
from sklearn.naive_bayes import GaussianNB
# Define base models
base_models = [
('lr', LogisticRegression()),
('dt', DecisionTreeClassifier()),
('svc', SVC(probability=True)),
('nb', GaussianNB())
]
# Stacking Classifier
stacking_clf = StackingClassifier(
estimators=base_models,
final_estimator=LogisticRegression(),
cv=5
)
stacking_clf.fit(X_train, y_train)
# Stacking Regressor
from sklearn.linear_model import Ridge, Lasso
from sklearn.tree import DecisionTreeRegressor
reg_base_models = [
('ridge', Ridge()),
('lasso', Lasso()),
('dt', DecisionTreeRegressor())
]
stacking_reg = StackingRegressor(
estimators=reg_base_models,
final_estimator=Ridge(),
cv=5
)
stacking_reg.fit(X_train, y_train)
Voting
from sklearn.ensemble import VotingClassifier, VotingRegressor
# Hard voting
voting_clf_hard = VotingClassifier(
estimators=base_models,
voting='hard'
)
voting_clf_hard.fit(X_train, y_train)
# Soft voting (uses predicted probabilities)
voting_clf_soft = VotingClassifier(
estimators=base_models,
voting='soft'
)
voting_clf_soft.fit(X_train, y_train)
# Voting Regressor
voting_reg = VotingRegressor(estimators=reg_base_models)
voting_reg.fit(X_train, y_train)
Naive Bayes
Based on Bayes' theorem with the "naive" assumption of feature independence:
P(y|x₁,...,xₙ) = P(y)·P(x₁,...,xₙ|y) / P(x₁,...,xₙ)
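The theorem can be applied by hand on a toy spam example; the "naive" step is multiplying the per-feature likelihoods as if they were independent. All probabilities below are illustrative:

```python
# Toy two-word spam filter; numbers are illustrative, not from real data
p_spam = 0.4                          # prior P(spam)
p_words_given_spam = 0.8 * 0.6        # P(x1|spam) * P(x2|spam), naive independence
p_words_given_ham = 0.1 * 0.3         # P(x1|ham) * P(x2|ham)

# Bayes' theorem: posterior proportional to prior * likelihood
numerator_spam = p_spam * p_words_given_spam
numerator_ham = (1 - p_spam) * p_words_given_ham
posterior_spam = numerator_spam / (numerator_spam + numerator_ham)
print(posterior_spam)  # ~0.914: the words strongly favor spam
```

The denominator P(x₁,...,xₙ) never has to be computed directly; normalizing over the class numerators is enough.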
Gaussian Naive Bayes
from sklearn.naive_bayes import GaussianNB
# Gaussian NB (assumes features follow normal distribution)
gnb = GaussianNB()
gnb.fit(X_train, y_train)
y_pred = gnb.predict(X_test)
y_proba = gnb.predict_proba(X_test)
Multinomial Naive Bayes
from sklearn.naive_bayes import MultinomialNB
from sklearn.feature_extraction.text import CountVectorizer
# Example with text data
texts = ["I love this", "This is bad", "Great product", "Terrible experience"]
labels = [1, 0, 1, 0]
vectorizer = CountVectorizer()
X_text = vectorizer.fit_transform(texts)
mnb = MultinomialNB(alpha=1.0)
mnb.fit(X_text, labels)
Bernoulli Naive Bayes
from sklearn.naive_bayes import BernoulliNB
# Bernoulli NB (for binary/boolean features)
bnb = BernoulliNB(alpha=1.0)
bnb.fit(X_train, y_train)
K-Nearest Neighbors
KNN is a non-parametric method that classifies based on the k nearest training examples.
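The whole method fits in a few lines; a from-scratch sketch (toy data, hypothetical helper name) makes the distance-then-vote logic explicit:

```python
import numpy as np
from collections import Counter

def knn_predict(X_train, y_train, x, k=3):
    # Euclidean distance from x to every training point
    dists = np.linalg.norm(X_train - x, axis=1)
    nearest = np.argsort(dists)[:k]        # indices of the k closest points
    votes = Counter(y_train[nearest])      # majority vote among the neighbors
    return votes.most_common(1)[0][0]

# Two well-separated toy clusters
X_train = np.array([[0, 0], [0, 1], [1, 0], [5, 5], [5, 6], [6, 5]])
y_train = np.array([0, 0, 0, 1, 1, 1])
print(knn_predict(X_train, y_train, np.array([0.5, 0.5])))  # -> 0
print(knn_predict(X_train, y_train, np.array([5.5, 5.5])))  # -> 1
```

There is no training step beyond storing the data, which is what "non-parametric" means here; all the work happens at prediction time.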
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
# KNN Classifier
knn_clf = KNeighborsClassifier(
n_neighbors=5,
weights='uniform', # or 'distance'
algorithm='auto', # 'ball_tree', 'kd_tree', 'brute'
metric='minkowski',
p=2 # p=2 for Euclidean, p=1 for Manhattan
)
knn_clf.fit(X_train, y_train)
# Distance-weighted KNN
knn_weighted = KNeighborsClassifier(n_neighbors=5, weights='distance')
knn_weighted.fit(X_train, y_train)
# KNN Regressor
knn_reg = KNeighborsRegressor(n_neighbors=5, weights='distance')
knn_reg.fit(X_train, y_train)
# Find optimal k
from sklearn.model_selection import cross_val_score
k_range = range(1, 31)
k_scores = []
for k in k_range:
knn = KNeighborsClassifier(n_neighbors=k)
scores = cross_val_score(knn, X_train, y_train, cv=5, scoring='accuracy')
k_scores.append(scores.mean())
optimal_k = k_range[np.argmax(k_scores)]
print(f"Optimal k: {optimal_k}")
Model Comparison
from sklearn.model_selection import cross_validate
import pandas as pd
# Define models to compare
models = {
'Logistic Regression': LogisticRegression(),
'Decision Tree': DecisionTreeClassifier(),
'Random Forest': RandomForestClassifier(),
'SVM': SVC(),
'KNN': KNeighborsClassifier(),
'Naive Bayes': GaussianNB(),
'XGBoost': XGBClassifier()
}
# Compare models
results = []
for name, model in models.items():
cv_results = cross_validate(
model, X_train, y_train,
cv=5,
scoring=['accuracy', 'precision', 'recall', 'f1'],
return_train_score=True
)
results.append({
'Model': name,
'Train Accuracy': cv_results['train_accuracy'].mean(),
'Test Accuracy': cv_results['test_accuracy'].mean(),
'Precision': cv_results['test_precision'].mean(),
'Recall': cv_results['test_recall'].mean(),
'F1': cv_results['test_f1'].mean()
})
# Display results
comparison_df = pd.DataFrame(results)
comparison_df = comparison_df.sort_values('Test Accuracy', ascending=False)
print(comparison_df)
Practical Tips
1. Data Preprocessing
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.impute import SimpleImputer
# Handle missing values
imputer = SimpleImputer(strategy='mean') # or 'median', 'most_frequent'
X_imputed = imputer.fit_transform(X)
# Feature scaling (important for SVM, KNN, Neural Networks)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
2. Feature Selection
from sklearn.feature_selection import SelectKBest, f_classif, RFE
# Univariate selection
selector = SelectKBest(f_classif, k=10)
X_selected = selector.fit_transform(X_train, y_train)
# Recursive Feature Elimination
rfe = RFE(estimator=RandomForestClassifier(), n_features_to_select=10)
X_rfe = rfe.fit_transform(X_train, y_train)
3. Pipeline
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
# Create pipeline
pipe = Pipeline([
('scaler', StandardScaler()),
('pca', PCA(n_components=10)),
('classifier', LogisticRegression())
])
pipe.fit(X_train, y_train)
y_pred = pipe.predict(X_test)
4. Hyperparameter Tuning
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
# Grid search
param_grid = {
'C': [0.1, 1, 10],
'kernel': ['rbf', 'linear'],
'gamma': ['scale', 'auto']
}
grid_search = GridSearchCV(SVC(), param_grid, cv=5, scoring='accuracy')
grid_search.fit(X_train, y_train)
print(f"Best parameters: {grid_search.best_params_}")
Resources
- scikit-learn documentation: https://scikit-learn.org/
- XGBoost documentation: https://xgboost.readthedocs.io/
- LightGBM documentation: https://lightgbm.readthedocs.io/
- "Introduction to Statistical Learning" by James et al.
- "Pattern Recognition and Machine Learning" by Bishop
Unsupervised Learning
Unsupervised learning discovers hidden patterns in data without labeled outputs.
Clustering
Clustering groups similar data points together without predefined labels.
K-Means Clustering
K-Means partitions data into k clusters by minimizing within-cluster variance.
Algorithm:
1. Initialize k centroids randomly
2. Assign each point to nearest centroid
3. Update centroids as mean of assigned points
4. Repeat steps 2-3 until convergence
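The steps above can be sketched directly in NumPy (a minimal version, assuming no cluster ever ends up empty — production implementations like scikit-learn's handle that case):

```python
import numpy as np

def kmeans(X, k, n_iter=100, seed=42):
    rng = np.random.default_rng(seed)
    # Step 1: pick k distinct data points as initial centroids
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(n_iter):
        # Step 2: assign each point to its nearest centroid
        dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
        labels = dists.argmin(axis=1)
        # Step 3: move each centroid to the mean of its assigned points
        new_centroids = np.array([X[labels == j].mean(axis=0) for j in range(k)])
        if np.allclose(new_centroids, centroids):  # step 4: stop on convergence
            break
        centroids = new_centroids
    return labels, centroids

# Two well-separated blobs
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0, 0.3, (50, 2)), rng.normal(5, 0.3, (50, 2))])
labels, centroids = kmeans(X, k=2)
```

Each iteration can only decrease the within-cluster variance (the inertia), which is why the loop converges, though possibly to a local optimum.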
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
# Generate synthetic data
X, y_true = make_blobs(n_samples=300, centers=4, cluster_std=0.6, random_state=42)
# K-Means clustering
kmeans = KMeans(n_clusters=4, random_state=42, n_init=10)
kmeans.fit(X)
# Predictions
y_pred = kmeans.predict(X)
centers = kmeans.cluster_centers_
# Visualization
plt.figure(figsize=(10, 6))
plt.scatter(X[:, 0], X[:, 1], c=y_pred, cmap='viridis', alpha=0.6)
plt.scatter(centers[:, 0], centers[:, 1], c='red', marker='X', s=200, edgecolors='black')
plt.title('K-Means Clustering')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.show()
# Cluster characteristics
print(f"Cluster centers:\n{centers}")
print(f"Inertia (sum of squared distances): {kmeans.inertia_:.2f}")
Choosing Optimal K
Elbow Method:
from sklearn.metrics import silhouette_score
# Elbow method
inertias = []
silhouettes = []
K_range = range(2, 11)
for k in K_range:
kmeans = KMeans(n_clusters=k, random_state=42, n_init=10)
kmeans.fit(X)
inertias.append(kmeans.inertia_)
silhouettes.append(silhouette_score(X, kmeans.labels_))
# Plot elbow curve
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
ax1.plot(K_range, inertias, 'bo-')
ax1.set_xlabel('Number of clusters (k)')
ax1.set_ylabel('Inertia')
ax1.set_title('Elbow Method')
ax1.grid(True)
ax2.plot(K_range, silhouettes, 'ro-')
ax2.set_xlabel('Number of clusters (k)')
ax2.set_ylabel('Silhouette Score')
ax2.set_title('Silhouette Analysis')
ax2.grid(True)
plt.tight_layout()
plt.show()
K-Means++
Improved initialization for K-Means:
# K-Means++ (default in scikit-learn)
kmeans_plus = KMeans(n_clusters=4, init='k-means++', random_state=42)
kmeans_plus.fit(X)
# Mini-batch K-Means (faster for large datasets)
from sklearn.cluster import MiniBatchKMeans
mini_kmeans = MiniBatchKMeans(n_clusters=4, random_state=42, batch_size=100)
mini_kmeans.fit(X)
Hierarchical Clustering
Builds a tree of clusters (dendrogram).
Agglomerative (Bottom-up):
from sklearn.cluster import AgglomerativeClustering
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import pdist
# Agglomerative clustering
agg_clustering = AgglomerativeClustering(
n_clusters=4,
linkage='ward' # 'complete', 'average', 'single'
)
y_pred_agg = agg_clustering.fit_predict(X)
# Create dendrogram
Z = linkage(X, method='ward')
plt.figure(figsize=(12, 6))
dendrogram(Z)
plt.title('Hierarchical Clustering Dendrogram')
plt.xlabel('Sample Index')
plt.ylabel('Distance')
plt.show()
# Different linkage methods
linkage_methods = ['ward', 'complete', 'average', 'single']
for method in linkage_methods:
agg = AgglomerativeClustering(n_clusters=4, linkage=method)
labels = agg.fit_predict(X)
print(f"{method.capitalize()} linkage - Silhouette: {silhouette_score(X, labels):.3f}")
DBSCAN
Density-Based Spatial Clustering finds core samples of high density.
from sklearn.cluster import DBSCAN
# DBSCAN
dbscan = DBSCAN(eps=0.5, min_samples=5)
y_pred_dbscan = dbscan.fit_predict(X)
# Number of clusters (excluding noise points labeled as -1)
n_clusters = len(set(y_pred_dbscan)) - (1 if -1 in y_pred_dbscan else 0)
n_noise = list(y_pred_dbscan).count(-1)
print(f"Number of clusters: {n_clusters}")
print(f"Number of noise points: {n_noise}")
# Visualization
plt.figure(figsize=(10, 6))
unique_labels = set(y_pred_dbscan)
colors = [plt.cm.Spectral(each) for each in np.linspace(0, 1, len(unique_labels))]
for k, col in zip(unique_labels, colors):
if k == -1:
col = [0, 0, 0, 1] # Black for noise
class_member_mask = (y_pred_dbscan == k)
xy = X[class_member_mask]
plt.plot(xy[:, 0], xy[:, 1], 'o', markerfacecolor=tuple(col),
markeredgecolor='k', markersize=6)
plt.title(f'DBSCAN Clustering\n{n_clusters} clusters, {n_noise} noise points')
plt.show()
# Grid search for optimal parameters
from sklearn.model_selection import ParameterGrid
param_grid = {
'eps': [0.3, 0.5, 0.7, 1.0],
'min_samples': [3, 5, 10]
}
best_score = -1
best_params = None
for params in ParameterGrid(param_grid):
dbscan = DBSCAN(**params)
labels = dbscan.fit_predict(X)
# Skip if all points are noise or only one cluster
if len(set(labels)) <= 1:
continue
score = silhouette_score(X, labels)
if score > best_score:
best_score = score
best_params = params
print(f"Best parameters: {best_params}")
print(f"Best silhouette score: {best_score:.3f}")
HDBSCAN
Hierarchical DBSCAN with better parameter selection:
# pip install hdbscan
import hdbscan
# HDBSCAN
clusterer = hdbscan.HDBSCAN(min_cluster_size=5, min_samples=3)
y_pred_hdbscan = clusterer.fit_predict(X)
# Cluster probabilities
probabilities = clusterer.probabilities_
print(f"Number of clusters: {len(set(y_pred_hdbscan)) - (1 if -1 in y_pred_hdbscan else 0)}")
print(f"Noise points: {list(y_pred_hdbscan).count(-1)}")
Gaussian Mixture Models
GMM assumes data is generated from a mixture of Gaussian distributions.
from sklearn.mixture import GaussianMixture
# Gaussian Mixture Model
gmm = GaussianMixture(
n_components=4,
covariance_type='full', # 'tied', 'diag', 'spherical'
random_state=42
)
gmm.fit(X)
# Predictions (hard clustering)
y_pred_gmm = gmm.predict(X)
# Soft clustering (probabilities)
probabilities = gmm.predict_proba(X)
print("Shape of probabilities:", probabilities.shape)
# Model parameters
print(f"Means:\n{gmm.means_}")
print(f"Covariances shape: {gmm.covariances_.shape}")
print(f"Weights: {gmm.weights_}")
# Bayesian Information Criterion (BIC) for model selection
n_components_range = range(2, 11)
bic_scores = []
aic_scores = []
for n_components in n_components_range:
gmm = GaussianMixture(n_components=n_components, random_state=42)
gmm.fit(X)
bic_scores.append(gmm.bic(X))
aic_scores.append(gmm.aic(X))
plt.figure(figsize=(10, 6))
plt.plot(n_components_range, bic_scores, 'bo-', label='BIC')
plt.plot(n_components_range, aic_scores, 'rs-', label='AIC')
plt.xlabel('Number of components')
plt.ylabel('Information Criterion')
plt.title('GMM Model Selection')
plt.legend()
plt.grid(True)
plt.show()
optimal_components = n_components_range[np.argmin(bic_scores)]
print(f"Optimal number of components: {optimal_components}")
Mean Shift
Finds clusters by locating peaks in density.
from sklearn.cluster import MeanShift, estimate_bandwidth
# Estimate bandwidth
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)
# Mean Shift clustering
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(X)
y_pred_ms = ms.labels_
cluster_centers = ms.cluster_centers_
n_clusters = len(np.unique(y_pred_ms))
print(f"Number of clusters: {n_clusters}")
Spectral Clustering
Uses eigenvalues of similarity matrix for clustering.
from sklearn.cluster import SpectralClustering
# Spectral clustering
spectral = SpectralClustering(
n_clusters=4,
affinity='rbf', # 'nearest_neighbors', 'precomputed'
assign_labels='discretize', # 'kmeans'
random_state=42
)
y_pred_spectral = spectral.fit_predict(X)
# Custom affinity matrix
from sklearn.metrics.pairwise import rbf_kernel
affinity_matrix = rbf_kernel(X, gamma=1.0)
spectral_custom = SpectralClustering(n_clusters=4, affinity='precomputed')
y_pred_spectral_custom = spectral_custom.fit_predict(affinity_matrix)
Dimensionality Reduction
Reducing the number of features while preserving important information.
Principal Component Analysis (PCA)
PCA finds orthogonal directions of maximum variance.
Mathematical Formulation:
Maximize: Var(Xw) subject to ||w|| = 1
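The maximizing directions w are the top eigenvectors of the covariance matrix, so PCA can be sketched directly via eigendecomposition (synthetic data with one dominant variance direction):

```python
import numpy as np

def pca_eig(X, n_components):
    # Center the data, then eigendecompose the covariance matrix
    Xc = X - X.mean(axis=0)
    cov = np.cov(Xc, rowvar=False)
    eigvals, eigvecs = np.linalg.eigh(cov)   # eigh returns ascending order
    order = np.argsort(eigvals)[::-1]        # largest variance first
    components = eigvecs[:, order[:n_components]]
    return Xc @ components, eigvals[order]

rng = np.random.default_rng(1)
# Stretch one axis so most variance lies along a single direction
X = rng.normal(size=(300, 2)) @ np.array([[3.0, 0.0], [0.0, 0.5]])
X_proj, eigvals = pca_eig(X, n_components=1)
print(eigvals)  # first eigenvalue dominates
```

scikit-learn's `PCA` uses an SVD of the centered data instead, which is numerically preferable but yields the same components.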
from sklearn.decomposition import PCA
from sklearn.datasets import load_digits
# Load high-dimensional data
digits = load_digits()
X = digits.data # 64 features (8x8 images)
y = digits.target
# PCA
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X)
# Explained variance
print(f"Explained variance ratio: {pca.explained_variance_ratio_}")
print(f"Total explained variance: {pca.explained_variance_ratio_.sum():.3f}")
# Visualization
plt.figure(figsize=(10, 6))
scatter = plt.scatter(X_pca[:, 0], X_pca[:, 1], c=y, cmap='tab10', alpha=0.6)
plt.colorbar(scatter)
plt.xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2%} variance)')
plt.ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2%} variance)')
plt.title('PCA of Digits Dataset')
plt.show()
# Determine number of components
pca_full = PCA()
pca_full.fit(X)
# Cumulative explained variance
cumsum_var = np.cumsum(pca_full.explained_variance_ratio_)
n_components_95 = np.argmax(cumsum_var >= 0.95) + 1
plt.figure(figsize=(10, 6))
plt.plot(range(1, len(cumsum_var) + 1), cumsum_var, 'bo-')
plt.axhline(y=0.95, color='r', linestyle='--', label='95% variance')
plt.axvline(x=n_components_95, color='g', linestyle='--',
label=f'{n_components_95} components')
plt.xlabel('Number of Components')
plt.ylabel('Cumulative Explained Variance')
plt.title('PCA Explained Variance')
plt.legend()
plt.grid(True)
plt.show()
print(f"Components needed for 95% variance: {n_components_95}")
# Incremental PCA for large datasets
from sklearn.decomposition import IncrementalPCA
ipca = IncrementalPCA(n_components=10, batch_size=100)
X_ipca = ipca.fit_transform(X)
# Kernel PCA for non-linear dimensionality reduction
from sklearn.decomposition import KernelPCA
kpca = KernelPCA(n_components=2, kernel='rbf', gamma=0.04)
X_kpca = kpca.fit_transform(X)
t-SNE
t-Distributed Stochastic Neighbor Embedding for visualization.
from sklearn.manifold import TSNE
# t-SNE
tsne = TSNE(
n_components=2,
perplexity=30,
learning_rate=200,
max_iter=1000,  # named 'n_iter' in scikit-learn < 1.5
random_state=42
)
X_tsne = tsne.fit_transform(X)
# Visualization
plt.figure(figsize=(10, 6))
scatter = plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=y, cmap='tab10', alpha=0.6)
plt.colorbar(scatter)
plt.title('t-SNE of Digits Dataset')
plt.xlabel('t-SNE 1')
plt.ylabel('t-SNE 2')
plt.show()
# Try different perplexities
fig, axes = plt.subplots(2, 2, figsize=(15, 15))
perplexities = [5, 30, 50, 100]
for ax, perplexity in zip(axes.ravel(), perplexities):
tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
X_embedded = tsne.fit_transform(X)
scatter = ax.scatter(X_embedded[:, 0], X_embedded[:, 1], c=y,
cmap='tab10', alpha=0.6)
ax.set_title(f'Perplexity = {perplexity}')
plt.tight_layout()
plt.show()
UMAP
Uniform Manifold Approximation and Projection (faster than t-SNE).
# pip install umap-learn
import umap
# UMAP
reducer = umap.UMAP(
n_components=2,
n_neighbors=15,
min_dist=0.1,
metric='euclidean',
random_state=42
)
X_umap = reducer.fit_transform(X)
# Visualization
plt.figure(figsize=(10, 6))
scatter = plt.scatter(X_umap[:, 0], X_umap[:, 1], c=y, cmap='tab10', alpha=0.6)
plt.colorbar(scatter)
plt.title('UMAP of Digits Dataset')
plt.xlabel('UMAP 1')
plt.ylabel('UMAP 2')
plt.show()
# Compare PCA, t-SNE, and UMAP
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
methods = [
('PCA', X_pca),
('t-SNE', X_tsne),
('UMAP', X_umap)
]
for ax, (name, X_reduced) in zip(axes, methods):
scatter = ax.scatter(X_reduced[:, 0], X_reduced[:, 1], c=y,
cmap='tab10', alpha=0.6)
ax.set_title(name)
ax.set_xlabel(f'{name} 1')
ax.set_ylabel(f'{name} 2')
plt.tight_layout()
plt.show()
Linear Discriminant Analysis (LDA)
Supervised dimensionality reduction that maximizes class separability.
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
# LDA (requires labels)
lda = LinearDiscriminantAnalysis(n_components=2)
X_lda = lda.fit_transform(X, y)
# Explained variance ratio
print(f"Explained variance ratio: {lda.explained_variance_ratio_}")
# Visualization
plt.figure(figsize=(10, 6))
scatter = plt.scatter(X_lda[:, 0], X_lda[:, 1], c=y, cmap='tab10', alpha=0.6)
plt.colorbar(scatter)
plt.xlabel(f'LD1 ({lda.explained_variance_ratio_[0]:.2%})')
plt.ylabel(f'LD2 ({lda.explained_variance_ratio_[1]:.2%})')
plt.title('LDA of Digits Dataset')
plt.show()
Autoencoders
Neural network-based dimensionality reduction.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
# Define autoencoder
class Autoencoder(nn.Module):
def __init__(self, input_dim, encoding_dim):
super(Autoencoder, self).__init__()
# Encoder
self.encoder = nn.Sequential(
nn.Linear(input_dim, 128),
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, encoding_dim)
)
# Decoder
self.decoder = nn.Sequential(
nn.Linear(encoding_dim, 64),
nn.ReLU(),
nn.Linear(64, 128),
nn.ReLU(),
nn.Linear(128, input_dim),
nn.Sigmoid()
)
def forward(self, x):
encoded = self.encoder(x)
decoded = self.decoder(encoded)
return decoded
def encode(self, x):
return self.encoder(x)
# Prepare data
X_normalized = (X - X.min()) / (X.max() - X.min())
X_tensor = torch.FloatTensor(X_normalized)
dataset = TensorDataset(X_tensor, X_tensor)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
# Initialize model
input_dim = X.shape[1]
encoding_dim = 2
model = Autoencoder(input_dim, encoding_dim)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Train
n_epochs = 50
for epoch in range(n_epochs):
total_loss = 0
for batch_x, _ in dataloader:
optimizer.zero_grad()
outputs = model(batch_x)
loss = criterion(outputs, batch_x)
loss.backward()
optimizer.step()
total_loss += loss.item()
if (epoch + 1) % 10 == 0:
print(f'Epoch [{epoch+1}/{n_epochs}], Loss: {total_loss/len(dataloader):.4f}')
# Get encoded representations
model.eval()
with torch.no_grad():
X_encoded = model.encode(X_tensor).numpy()
# Visualization
plt.figure(figsize=(10, 6))
scatter = plt.scatter(X_encoded[:, 0], X_encoded[:, 1], c=y, cmap='tab10', alpha=0.6)
plt.colorbar(scatter)
plt.title('Autoencoder Dimensionality Reduction')
plt.xlabel('Encoded Dimension 1')
plt.ylabel('Encoded Dimension 2')
plt.show()
Non-negative Matrix Factorization (NMF)
Decomposes data into non-negative components.
from sklearn.decomposition import NMF
# NMF (requires non-negative data)
X_nonneg = X - X.min() + 1e-10
nmf = NMF(n_components=10, init='random', random_state=42, max_iter=500)
W = nmf.fit_transform(X_nonneg) # Coefficient matrix
H = nmf.components_ # Component matrix
print(f"Reconstruction error: {nmf.reconstruction_err_:.2f}")
print(f"W shape: {W.shape}, H shape: {H.shape}")
# Visualize components
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for i, ax in enumerate(axes.ravel()):
ax.imshow(H[i].reshape(8, 8), cmap='gray')
ax.set_title(f'Component {i+1}')
ax.axis('off')
plt.tight_layout()
plt.show()
Truncated SVD
Similar to PCA but works with sparse matrices.
from sklearn.decomposition import TruncatedSVD
# Truncated SVD
svd = TruncatedSVD(n_components=10, random_state=42)
X_svd = svd.fit_transform(X)
print(f"Explained variance ratio: {svd.explained_variance_ratio_}")
print(f"Total explained variance: {svd.explained_variance_ratio_.sum():.3f}")
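The sparse-input support is the main reason to reach for TruncatedSVD over PCA (PCA would densify via centering). A quick sketch on a random sparse matrix (illustrative data, roughly 90% zeros):

```python
import numpy as np
from scipy import sparse
from sklearn.decomposition import TruncatedSVD

# A sparse document-term-style matrix (random, illustrative)
rng = np.random.default_rng(42)
dense = rng.random((100, 50))
dense[dense < 0.9] = 0.0                  # zero out ~90% of entries
X_sparse = sparse.csr_matrix(dense)

svd = TruncatedSVD(n_components=5, random_state=42)
X_reduced = svd.fit_transform(X_sparse)   # accepts sparse input directly
print(X_reduced.shape)  # (100, 5)
```

Applied to a TF-IDF matrix this is known as Latent Semantic Analysis (LSA).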
Anomaly Detection
Identifying unusual patterns that don't conform to expected behavior.
Isolation Forest
from sklearn.ensemble import IsolationForest
# Isolation Forest
iso_forest = IsolationForest(
n_estimators=100,
contamination=0.1, # Expected proportion of outliers
random_state=42
)
y_pred_outliers = iso_forest.fit_predict(X)
# -1 for outliers, 1 for inliers
n_outliers = (y_pred_outliers == -1).sum()
print(f"Number of outliers detected: {n_outliers}")
# Anomaly scores
anomaly_scores = iso_forest.score_samples(X)
# Visualization
plt.figure(figsize=(10, 6))
scatter = plt.scatter(X_pca[:, 0], X_pca[:, 1], c=anomaly_scores, cmap='RdYlGn')
plt.colorbar(scatter, label='Anomaly Score')
plt.title('Isolation Forest Anomaly Scores')
plt.xlabel('PC1')
plt.ylabel('PC2')
plt.show()
Local Outlier Factor (LOF)
from sklearn.neighbors import LocalOutlierFactor
# Local Outlier Factor
lof = LocalOutlierFactor(n_neighbors=20, contamination=0.1)
y_pred_lof = lof.fit_predict(X)
# Negative outlier factor (lower values = more anomalous)
outlier_scores = lof.negative_outlier_factor_
n_outliers = (y_pred_lof == -1).sum()
print(f"Number of outliers detected: {n_outliers}")
One-Class SVM
from sklearn.svm import OneClassSVM
# One-Class SVM
oc_svm = OneClassSVM(nu=0.1, kernel='rbf', gamma='auto')
y_pred_oc = oc_svm.fit_predict(X)
n_outliers = (y_pred_oc == -1).sum()
print(f"Number of outliers detected: {n_outliers}")
Elliptic Envelope
from sklearn.covariance import EllipticEnvelope
# Elliptic Envelope (assumes Gaussian distribution)
elliptic = EllipticEnvelope(contamination=0.1, random_state=42)
y_pred_elliptic = elliptic.fit_predict(X)
n_outliers = (y_pred_elliptic == -1).sum()
print(f"Number of outliers detected: {n_outliers}")
Density Estimation
Estimating the probability density function of data.
Kernel Density Estimation
from sklearn.neighbors import KernelDensity
# Kernel Density Estimation
kde = KernelDensity(kernel='gaussian', bandwidth=0.5)
kde.fit(X)
# Score samples (log-likelihood)
log_density = kde.score_samples(X)
# Sample from the learned distribution
samples = kde.sample(100, random_state=42)
# Visualization (for 2D data)
if X.shape[1] == 2:
xx, yy = np.meshgrid(np.linspace(X[:, 0].min(), X[:, 0].max(), 100),
np.linspace(X[:, 1].min(), X[:, 1].max(), 100))
Z = np.exp(kde.score_samples(np.c_[xx.ravel(), yy.ravel()]))
Z = Z.reshape(xx.shape)
plt.figure(figsize=(10, 6))
plt.contourf(xx, yy, Z, levels=20, cmap='viridis', alpha=0.6)
plt.scatter(X[:, 0], X[:, 1], c='red', alpha=0.3, s=10)
plt.colorbar(label='Density')
plt.title('Kernel Density Estimation')
plt.show()
Association Rules
Finding interesting relationships between variables.
Apriori Algorithm
# pip install mlxtend
from mlxtend.frequent_patterns import apriori, association_rules
import pandas as pd
# Example transaction data
transactions = [
['milk', 'bread', 'butter'],
['milk', 'bread'],
['milk', 'butter'],
['bread', 'butter'],
['milk', 'bread', 'butter', 'cheese'],
['milk', 'cheese'],
['bread', 'cheese']
]
# Convert to one-hot encoded DataFrame
from mlxtend.preprocessing import TransactionEncoder
te = TransactionEncoder()
te_array = te.fit(transactions).transform(transactions)
df = pd.DataFrame(te_array, columns=te.columns_)
# Find frequent itemsets
frequent_itemsets = apriori(df, min_support=0.3, use_colnames=True)
print("Frequent Itemsets:")
print(frequent_itemsets)
# Generate association rules
rules = association_rules(frequent_itemsets, metric='confidence', min_threshold=0.5)
print("\nAssociation Rules:")
print(rules[['antecedents', 'consequents', 'support', 'confidence', 'lift']])
# Filter interesting rules
interesting_rules = rules[(rules['lift'] > 1) & (rules['confidence'] > 0.6)]
print("\nInteresting Rules:")
print(interesting_rules)
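For intuition, the three metrics printed above can also be computed by hand for a single rule. The sketch below recomputes support, confidence, and lift for the rule {milk} → {bread} directly from the same example transactions (no mlxtend required):

```python
transactions = [
    ['milk', 'bread', 'butter'],
    ['milk', 'bread'],
    ['milk', 'butter'],
    ['bread', 'butter'],
    ['milk', 'bread', 'butter', 'cheese'],
    ['milk', 'cheese'],
    ['bread', 'cheese'],
]
n = len(transactions)

def support(itemset):
    """Fraction of transactions containing every item in the itemset"""
    return sum(set(itemset) <= set(t) for t in transactions) / n

# Rule: {milk} -> {bread}
s_milk = support(['milk'])             # P(milk)
s_both = support(['milk', 'bread'])    # P(milk AND bread) = support of the rule
confidence = s_both / s_milk           # P(bread | milk)
lift = confidence / support(['bread']) # >1: positive association, <1: negative
print(f"support={s_both:.3f}, confidence={confidence:.3f}, lift={lift:.3f}")
```

Note that this rule has lift below 1 in the toy data, which is exactly why the filtering step above keeps only rules with `lift > 1`.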
Clustering Evaluation Metrics
from sklearn.metrics import (
silhouette_score, davies_bouldin_score,
calinski_harabasz_score, adjusted_rand_score
)
# Silhouette Score (higher is better, range: [-1, 1])
silhouette = silhouette_score(X, y_pred)
# Davies-Bouldin Index (lower is better)
davies_bouldin = davies_bouldin_score(X, y_pred)
# Calinski-Harabasz Index (higher is better)
calinski_harabasz = calinski_harabasz_score(X, y_pred)
# Adjusted Rand Index (if true labels available)
ari = adjusted_rand_score(y_true, y_pred)
print(f"Silhouette Score: {silhouette:.3f}")
print(f"Davies-Bouldin Index: {davies_bouldin:.3f}")
print(f"Calinski-Harabasz Index: {calinski_harabasz:.3f}")
print(f"Adjusted Rand Index: {ari:.3f}")
# Silhouette analysis per sample
from sklearn.metrics import silhouette_samples
silhouette_vals = silhouette_samples(X, y_pred)
# Visualize silhouette scores
fig, ax = plt.subplots(figsize=(10, 6))
y_lower = 10
for i in range(len(set(y_pred))):
cluster_silhouette_vals = silhouette_vals[y_pred == i]
cluster_silhouette_vals.sort()
size_cluster_i = cluster_silhouette_vals.shape[0]
y_upper = y_lower + size_cluster_i
ax.fill_betweenx(np.arange(y_lower, y_upper),
0, cluster_silhouette_vals,
alpha=0.7)
ax.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))
y_lower = y_upper + 10
ax.set_xlabel("Silhouette Coefficient")
ax.set_ylabel("Cluster")
ax.axvline(x=silhouette, color="red", linestyle="--")
ax.set_title("Silhouette Analysis")
plt.show()
Practical Tips
1. Feature Scaling
# Always scale features for distance-based methods
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
2. Handling High-Dimensional Data
# Apply dimensionality reduction before clustering
pca = PCA(n_components=0.95) # Keep 95% variance
X_reduced = pca.fit_transform(X_scaled)
kmeans = KMeans(n_clusters=4)
kmeans.fit(X_reduced)
3. Visualizing Clusters
def plot_clusters_3d(X, labels, title='3D Cluster Visualization'):
from mpl_toolkits.mplot3d import Axes3D
# Reduce to 3D if needed
if X.shape[1] > 3:
pca = PCA(n_components=3)
X = pca.fit_transform(X)
fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(111, projection='3d')
scatter = ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=labels, cmap='viridis')
ax.set_title(title)
plt.colorbar(scatter)
plt.show()
Resources
- scikit-learn documentation: https://scikit-learn.org/
- "Pattern Recognition and Machine Learning" by Christopher Bishop
- "Introduction to Data Mining" by Tan, Steinbach, Kumar
- UMAP documentation: https://umap-learn.readthedocs.io/
Reinforcement Learning
Reinforcement Learning (RL) is about learning to make decisions by interacting with an environment to maximize cumulative reward.
Table of Contents
- Core Concepts
- Markov Decision Processes
- Dynamic Programming
- Monte Carlo Methods
- Temporal Difference Learning
- Q-Learning
- SARSA
- Policy Gradient Methods
- Actor-Critic Methods
- Multi-Armed Bandits
Core Concepts
The RL Framework
Key Components:
- Agent: The learner/decision maker
- Environment: What the agent interacts with
- State (s): Current situation
- Action (a): What the agent can do
- Reward (r): Feedback signal
- Policy (π): Strategy for selecting actions
- Value Function (V): Expected future reward from a state
- Q-Function (Q): Expected future reward for state-action pairs
Mathematical Framework:
At each time step t:
- Agent observes state s_t
- Agent takes action a_t
- Environment transitions to s_{t+1}
- Agent receives reward r_{t+1}
import numpy as np
import matplotlib.pyplot as plt
# Simple RL environment example
class GridWorld:
    def __init__(self, size=5):
        self.size = size
        self.num_actions = 4  # up, right, down, left
        self.state = (0, 0)
        self.goal = (size-1, size-1)
    def get_all_states(self):
        """All (row, col) positions in the grid"""
        return [(i, j) for i in range(self.size) for j in range(self.size)]
def reset(self):
self.state = (0, 0)
return self.state
def step(self, action):
# Actions: 0=up, 1=right, 2=down, 3=left
x, y = self.state
if action == 0: # up
x = max(0, x - 1)
elif action == 1: # right
y = min(self.size - 1, y + 1)
elif action == 2: # down
x = min(self.size - 1, x + 1)
elif action == 3: # left
y = max(0, y - 1)
self.state = (x, y)
# Reward
if self.state == self.goal:
reward = 1.0
done = True
else:
reward = -0.01 # Small penalty for each step
done = False
return self.state, reward, done
def render(self):
        grid = np.zeros((self.size, self.size))
        grid[self.goal] = 0.5
        grid[self.state] = 1  # drawn last so the agent stays visible on the goal cell
plt.imshow(grid, cmap='hot')
plt.title(f'State: {self.state}')
plt.show()
# Example usage
env = GridWorld(size=5)
state = env.reset()
print(f"Initial state: {state}")
# Take random actions
for _ in range(5):
action = np.random.randint(0, 4)
state, reward, done = env.step(action)
print(f"State: {state}, Reward: {reward}, Done: {done}")
if done:
break
Return and Discounting
Return (G_t): Total discounted reward accumulated from time t onward
G_t = R_{t+1} + γR_{t+2} + γ²R_{t+3} + ... = Σ_{k=0}^∞ γ^k R_{t+k+1}
Where γ (gamma) is the discount factor (0 ≤ γ ≤ 1):
- γ = 0: Only immediate rewards matter
- γ = 1: All future rewards equally important
- γ closer to 1: More far-sighted agent
def calculate_return(rewards, gamma=0.99):
"""Calculate discounted return from a list of rewards"""
G = 0
returns = []
for r in reversed(rewards):
G = r + gamma * G
returns.insert(0, G)
return returns
# Example
rewards = [1, 0, 0, 1, 0]
returns = calculate_return(rewards, gamma=0.9)
print(f"Rewards: {rewards}")
print(f"Returns: {returns}")
Markov Decision Processes
An MDP is defined by a tuple (S, A, P, R, γ):
- S: Set of states
- A: Set of actions
- P: Transition probability P(s'|s,a)
- R: Reward function R(s,a,s')
- γ: Discount factor
Markov Property: the future depends only on the current state, not on the full history
P(s_{t+1}|s_t, a_t, s_{t-1}, a_{t-1}, ...) = P(s_{t+1}|s_t, a_t)
class MDP:
def __init__(self, states, actions, transitions, rewards, gamma=0.99):
self.states = states
self.actions = actions
self.transitions = transitions # P(s'|s,a)
self.rewards = rewards # R(s,a,s')
self.gamma = gamma
def get_transition_prob(self, state, action, next_state):
return self.transitions.get((state, action, next_state), 0.0)
def get_reward(self, state, action, next_state):
return self.rewards.get((state, action, next_state), 0.0)
# Example: Simple MDP
states = ['s0', 's1', 's2']
actions = ['a0', 'a1']
transitions = {
('s0', 'a0', 's1'): 0.8,
('s0', 'a0', 's0'): 0.2,
('s0', 'a1', 's2'): 0.9,
('s0', 'a1', 's0'): 0.1,
('s1', 'a0', 's2'): 1.0,
('s2', 'a0', 's2'): 1.0,
}
rewards = {
('s0', 'a0', 's1'): -1,
('s0', 'a1', 's2'): 10,
('s1', 'a0', 's2'): 5,
}
mdp = MDP(states, actions, transitions, rewards)
Value Functions
State-Value Function V^π(s):
V^π(s) = E_π[G_t | S_t = s]
= E_π[Σ_{k=0}^∞ γ^k R_{t+k+1} | S_t = s]
Action-Value Function Q^π(s,a):
Q^π(s,a) = E_π[G_t | S_t = s, A_t = a]
Bellman Equations:
V^π(s) = Σ_a π(a|s) Σ_{s',r} p(s',r|s,a)[r + γV^π(s')]
Q^π(s,a) = Σ_{s',r} p(s',r|s,a)[r + γΣ_{a'} π(a'|s')Q^π(s',a')]
Optimal Value Functions:
V*(s) = max_π V^π(s) = max_a Q*(s,a)
Q*(s,a) = E[R_{t+1} + γV*(S_{t+1}) | S_t=s, A_t=a]
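As a concrete check of the Bellman expectation equation, one backup of V^π(s0) can be computed numerically on the toy MDP defined earlier (a minimal sketch; the uniform policy at s0 and starting from V = 0 are assumptions for illustration):

```python
gamma = 0.99

# Same toy MDP as above: P(s'|s,a) and R(s,a,s')
transitions = {
    ('s0', 'a0', 's1'): 0.8, ('s0', 'a0', 's0'): 0.2,
    ('s0', 'a1', 's2'): 0.9, ('s0', 'a1', 's0'): 0.1,
    ('s1', 'a0', 's2'): 1.0, ('s2', 'a0', 's2'): 1.0,
}
rewards = {('s0', 'a0', 's1'): -1, ('s0', 'a1', 's2'): 10, ('s1', 'a0', 's2'): 5}
states, actions = ['s0', 's1', 's2'], ['a0', 'a1']

V = {s: 0.0 for s in states}                  # start from V = 0
pi = {('s0', 'a0'): 0.5, ('s0', 'a1'): 0.5}   # uniform policy at s0

# One backup: V(s0) = sum_a pi(a|s0) sum_s' P(s'|s0,a) [R(s0,a,s') + gamma V(s')]
v_s0 = 0.0
for a in actions:
    for s_prime in states:
        p = transitions.get(('s0', a, s_prime), 0.0)
        r = rewards.get(('s0', a, s_prime), 0.0)
        v_s0 += pi.get(('s0', a), 0.0) * p * (r + gamma * V[s_prime])
print(f"V(s0) after one backup: {v_s0:.2f}")
```

With V = 0 everywhere, the backup reduces to the expected immediate reward: 0.5·0.8·(−1) + 0.5·0.9·10 = 4.1. Iterating this backup to convergence is exactly the policy evaluation algorithm of the next section.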
Dynamic Programming
DP methods assume full knowledge of the MDP.
Policy Evaluation
Compute value function for a given policy.
def policy_evaluation(policy, mdp, theta=1e-6):
"""
Evaluate a policy using iterative policy evaluation
Args:
policy: dict mapping states to action probabilities
mdp: MDP object
theta: convergence threshold
"""
V = {s: 0 for s in mdp.states}
while True:
delta = 0
for s in mdp.states:
v = V[s]
new_v = 0
# Sum over actions
for a in mdp.actions:
action_prob = policy.get((s, a), 0)
# Sum over next states
for s_prime in mdp.states:
trans_prob = mdp.get_transition_prob(s, a, s_prime)
reward = mdp.get_reward(s, a, s_prime)
new_v += action_prob * trans_prob * (reward + mdp.gamma * V[s_prime])
V[s] = new_v
delta = max(delta, abs(v - V[s]))
if delta < theta:
break
return V
# Example: Uniform random policy
random_policy = {
('s0', 'a0'): 0.5,
('s0', 'a1'): 0.5,
('s1', 'a0'): 1.0,
('s2', 'a0'): 1.0,
}
V = policy_evaluation(random_policy, mdp)
print("State values:", V)
Policy Iteration
def policy_iteration(mdp, theta=1e-6):
"""
Find optimal policy using policy iteration
"""
# Initialize random policy
policy = {}
for s in mdp.states:
action = np.random.choice(mdp.actions)
for a in mdp.actions:
policy[(s, a)] = 1.0 if a == action else 0.0
while True:
# Policy Evaluation
V = policy_evaluation(policy, mdp, theta)
# Policy Improvement
policy_stable = True
for s in mdp.states:
old_action = None
for a in mdp.actions:
if policy.get((s, a), 0) == 1.0:
old_action = a
break
# Find best action
action_values = {}
for a in mdp.actions:
q = 0
for s_prime in mdp.states:
trans_prob = mdp.get_transition_prob(s, a, s_prime)
reward = mdp.get_reward(s, a, s_prime)
q += trans_prob * (reward + mdp.gamma * V[s_prime])
action_values[a] = q
best_action = max(action_values, key=action_values.get)
# Update policy
for a in mdp.actions:
policy[(s, a)] = 1.0 if a == best_action else 0.0
if best_action != old_action:
policy_stable = False
if policy_stable:
break
return policy, V
optimal_policy, optimal_V = policy_iteration(mdp)
print("Optimal policy:", optimal_policy)
print("Optimal values:", optimal_V)
Value Iteration
def value_iteration(mdp, theta=1e-6):
"""
Find optimal policy using value iteration
"""
V = {s: 0 for s in mdp.states}
while True:
delta = 0
for s in mdp.states:
v = V[s]
# Find max over actions
action_values = []
for a in mdp.actions:
q = 0
for s_prime in mdp.states:
trans_prob = mdp.get_transition_prob(s, a, s_prime)
reward = mdp.get_reward(s, a, s_prime)
q += trans_prob * (reward + mdp.gamma * V[s_prime])
action_values.append(q)
V[s] = max(action_values) if action_values else 0
delta = max(delta, abs(v - V[s]))
if delta < theta:
break
# Extract policy
policy = {}
for s in mdp.states:
action_values = {}
for a in mdp.actions:
q = 0
for s_prime in mdp.states:
trans_prob = mdp.get_transition_prob(s, a, s_prime)
reward = mdp.get_reward(s, a, s_prime)
q += trans_prob * (reward + mdp.gamma * V[s_prime])
action_values[a] = q
best_action = max(action_values, key=action_values.get)
for a in mdp.actions:
policy[(s, a)] = 1.0 if a == best_action else 0.0
return policy, V
optimal_policy, optimal_V = value_iteration(mdp)
Monte Carlo Methods
MC methods learn from complete episodes without needing a model of the environment.
First-Visit MC Prediction
def first_visit_mc_prediction(env, policy, num_episodes=1000, gamma=0.99):
"""
Estimate state-value function using first-visit MC
"""
    returns = {s: [] for s in env.get_all_states()}
    V = {s: 0 for s in env.get_all_states()}
for episode in range(num_episodes):
# Generate episode
episode_data = []
state = env.reset()
        done = False
        steps = 0
        # Cap episode length: a fixed deterministic policy may never reach the goal
        while not done and steps < 500:
            action = policy[state]
            next_state, reward, done = env.step(action)
            episode_data.append((state, action, reward))
            state = next_state
            steps += 1
# Calculate returns
G = 0
visited_states = set()
for t in reversed(range(len(episode_data))):
state, action, reward = episode_data[t]
G = reward + gamma * G
# First-visit: only update if state not seen earlier
if state not in visited_states:
returns[state].append(G)
V[state] = np.mean(returns[state])
visited_states.add(state)
return V
# Example usage with GridWorld
env = GridWorld(size=4)
# Define a simple policy
policy = {state: np.random.randint(0, 4) for state in
[(i, j) for i in range(4) for j in range(4)]}
V = first_visit_mc_prediction(env, policy, num_episodes=10000)
Monte Carlo Control (Epsilon-Greedy)
def mc_control_epsilon_greedy(env, num_episodes=10000, gamma=0.99, epsilon=0.1):
"""
Monte Carlo control with epsilon-greedy policy
"""
Q = {}
returns = {}
# Initialize Q-values
for state in env.get_all_states():
for action in range(env.num_actions):
Q[(state, action)] = 0
returns[(state, action)] = []
for episode in range(num_episodes):
# Generate episode with epsilon-greedy policy
episode_data = []
state = env.reset()
done = False
while not done:
# Epsilon-greedy action selection
if np.random.random() < epsilon:
action = np.random.randint(0, env.num_actions)
else:
q_values = [Q.get((state, a), 0) for a in range(env.num_actions)]
action = np.argmax(q_values)
next_state, reward, done = env.step(action)
episode_data.append((state, action, reward))
state = next_state
# Update Q-values
G = 0
visited = set()
for t in reversed(range(len(episode_data))):
state, action, reward = episode_data[t]
G = reward + gamma * G
if (state, action) not in visited:
returns[(state, action)].append(G)
Q[(state, action)] = np.mean(returns[(state, action)])
visited.add((state, action))
# Extract policy
policy = {}
for state in env.get_all_states():
q_values = [Q.get((state, a), 0) for a in range(env.num_actions)]
policy[state] = np.argmax(q_values)
return policy, Q
Temporal Difference Learning
TD methods learn from incomplete episodes by bootstrapping.
TD(0) Prediction
TD Update Rule:
V(S_t) ← V(S_t) + α[R_{t+1} + γV(S_{t+1}) - V(S_t)]
Where:
- α is the learning rate
- R_{t+1} + γV(S_{t+1}) is the TD target
- δ_t = R_{t+1} + γV(S_{t+1}) - V(S_t) is the TD error
def td_0_prediction(env, policy, num_episodes=1000, alpha=0.1, gamma=0.99):
"""
TD(0) prediction for estimating state values
"""
V = {s: 0 for s in env.get_all_states()}
for episode in range(num_episodes):
state = env.reset()
done = False
while not done:
action = policy[state]
next_state, reward, done = env.step(action)
# TD update
if not done:
td_target = reward + gamma * V[next_state]
else:
td_target = reward
td_error = td_target - V[state]
V[state] += alpha * td_error
state = next_state
return V
TD(λ) - Eligibility Traces
def td_lambda_prediction(env, policy, num_episodes=1000,
alpha=0.1, gamma=0.99, lambda_=0.9):
"""
TD(λ) prediction with eligibility traces
"""
V = {s: 0 for s in env.get_all_states()}
for episode in range(num_episodes):
E = {s: 0 for s in env.get_all_states()} # Eligibility traces
state = env.reset()
done = False
while not done:
action = policy[state]
next_state, reward, done = env.step(action)
# Calculate TD error
if not done:
td_error = reward + gamma * V[next_state] - V[state]
else:
td_error = reward - V[state]
# Update eligibility trace for current state
E[state] += 1
# Update all states
for s in env.get_all_states():
V[s] += alpha * td_error * E[s]
E[s] *= gamma * lambda_
state = next_state
return V
Q-Learning
Q-Learning is an off-policy TD control algorithm.
Q-Learning Update:
Q(S_t, A_t) ← Q(S_t, A_t) + α[R_{t+1} + γ max_a Q(S_{t+1}, a) - Q(S_t, A_t)]
class QLearningAgent:
def __init__(self, state_space, action_space, alpha=0.1, gamma=0.99, epsilon=0.1):
self.state_space = state_space
self.action_space = action_space
self.alpha = alpha
self.gamma = gamma
self.epsilon = epsilon
# Initialize Q-table
self.q_table = {}
for state in state_space:
for action in action_space:
self.q_table[(state, action)] = 0.0
def get_action(self, state, training=True):
"""Epsilon-greedy action selection"""
if training and np.random.random() < self.epsilon:
return np.random.choice(self.action_space)
else:
q_values = [self.q_table.get((state, a), 0) for a in self.action_space]
return self.action_space[np.argmax(q_values)]
def update(self, state, action, reward, next_state, done):
"""Q-learning update"""
# Current Q-value
current_q = self.q_table[(state, action)]
# Maximum Q-value for next state
if not done:
max_next_q = max([self.q_table.get((next_state, a), 0)
for a in self.action_space])
else:
max_next_q = 0
# Q-learning update
td_target = reward + self.gamma * max_next_q
td_error = td_target - current_q
self.q_table[(state, action)] += self.alpha * td_error
return td_error
def train(self, env, num_episodes=1000):
"""Train the agent"""
episode_rewards = []
for episode in range(num_episodes):
state = env.reset()
total_reward = 0
done = False
while not done:
action = self.get_action(state, training=True)
next_state, reward, done = env.step(action)
self.update(state, action, reward, next_state, done)
state = next_state
total_reward += reward
episode_rewards.append(total_reward)
# Decay epsilon
self.epsilon = max(0.01, self.epsilon * 0.995)
if (episode + 1) % 100 == 0:
avg_reward = np.mean(episode_rewards[-100:])
print(f"Episode {episode + 1}, Avg Reward: {avg_reward:.2f}, Epsilon: {self.epsilon:.3f}")
return episode_rewards
# Example usage
env = GridWorld(size=5)
state_space = [(i, j) for i in range(5) for j in range(5)]
action_space = [0, 1, 2, 3] # up, right, down, left
agent = QLearningAgent(state_space, action_space)
rewards = agent.train(env, num_episodes=5000)
# Plot learning curve
plt.figure(figsize=(10, 6))
plt.plot(np.convolve(rewards, np.ones(100)/100, mode='valid'))
plt.xlabel('Episode')
plt.ylabel('Average Reward (100 episodes)')
plt.title('Q-Learning Training Progress')
plt.grid(True)
plt.show()
Double Q-Learning
Reduces maximization bias in Q-learning.
class DoubleQLearningAgent:
def __init__(self, state_space, action_space, alpha=0.1, gamma=0.99, epsilon=0.1):
self.state_space = state_space
self.action_space = action_space
self.alpha = alpha
self.gamma = gamma
self.epsilon = epsilon
# Two Q-tables
self.q_table_1 = {(s, a): 0.0 for s in state_space for a in action_space}
self.q_table_2 = {(s, a): 0.0 for s in state_space for a in action_space}
def get_action(self, state, training=True):
"""Epsilon-greedy using average of both Q-tables"""
if training and np.random.random() < self.epsilon:
return np.random.choice(self.action_space)
else:
q_values = [(self.q_table_1[(state, a)] + self.q_table_2[(state, a)]) / 2
for a in self.action_space]
return self.action_space[np.argmax(q_values)]
def update(self, state, action, reward, next_state, done):
"""Double Q-learning update"""
# Randomly choose which Q-table to update
if np.random.random() < 0.5:
q_table_update = self.q_table_1
q_table_target = self.q_table_2
else:
q_table_update = self.q_table_2
q_table_target = self.q_table_1
current_q = q_table_update[(state, action)]
if not done:
# Use one Q-table to select action, other to evaluate
best_action = max(self.action_space,
key=lambda a: q_table_update[(next_state, a)])
max_next_q = q_table_target[(next_state, best_action)]
else:
max_next_q = 0
td_target = reward + self.gamma * max_next_q
td_error = td_target - current_q
q_table_update[(state, action)] += self.alpha * td_error
return td_error
SARSA
SARSA is an on-policy TD control algorithm.
SARSA Update:
Q(S_t, A_t) ← Q(S_t, A_t) + α[R_{t+1} + γQ(S_{t+1}, A_{t+1}) - Q(S_t, A_t)]
class SARSAAgent:
def __init__(self, state_space, action_space, alpha=0.1, gamma=0.99, epsilon=0.1):
self.state_space = state_space
self.action_space = action_space
self.alpha = alpha
self.gamma = gamma
self.epsilon = epsilon
# Initialize Q-table
self.q_table = {(s, a): 0.0 for s in state_space for a in action_space}
def get_action(self, state, training=True):
"""Epsilon-greedy action selection"""
if training and np.random.random() < self.epsilon:
return np.random.choice(self.action_space)
else:
q_values = [self.q_table[(state, a)] for a in self.action_space]
return self.action_space[np.argmax(q_values)]
def update(self, state, action, reward, next_state, next_action, done):
"""SARSA update"""
current_q = self.q_table[(state, action)]
if not done:
next_q = self.q_table[(next_state, next_action)]
else:
next_q = 0
td_target = reward + self.gamma * next_q
td_error = td_target - current_q
self.q_table[(state, action)] += self.alpha * td_error
return td_error
def train(self, env, num_episodes=1000):
"""Train the agent"""
episode_rewards = []
for episode in range(num_episodes):
state = env.reset()
action = self.get_action(state, training=True)
total_reward = 0
done = False
while not done:
next_state, reward, done = env.step(action)
next_action = self.get_action(next_state, training=True)
self.update(state, action, reward, next_state, next_action, done)
state = next_state
action = next_action
total_reward += reward
episode_rewards.append(total_reward)
self.epsilon = max(0.01, self.epsilon * 0.995)
return episode_rewards
Policy Gradient Methods
Policy gradient methods directly optimize the policy.
REINFORCE Algorithm
Policy Gradient Theorem:
∇_θ J(θ) = E_π[∇_θ log π(a|s,θ) Q^π(s,a)]
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
class PolicyNetwork(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=64):
super(PolicyNetwork, self).__init__()
self.fc1 = nn.Linear(state_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, action_dim)
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.softmax(self.fc3(x), dim=-1)
return x
class REINFORCEAgent:
def __init__(self, state_dim, action_dim, lr=0.001, gamma=0.99):
self.gamma = gamma
self.policy_net = PolicyNetwork(state_dim, action_dim)
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
self.saved_log_probs = []
self.rewards = []
def select_action(self, state):
"""Select action using current policy"""
state = torch.FloatTensor(state).unsqueeze(0)
probs = self.policy_net(state)
action_dist = torch.distributions.Categorical(probs)
action = action_dist.sample()
# Save log probability for training
self.saved_log_probs.append(action_dist.log_prob(action))
return action.item()
def update(self):
"""Update policy using REINFORCE"""
R = 0
returns = []
# Calculate returns
for r in reversed(self.rewards):
R = r + self.gamma * R
returns.insert(0, R)
        # Normalize returns (stabilizes training; std is undefined for 1-step episodes)
        returns = torch.tensor(returns)
        if len(returns) > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-8)
# Calculate loss
policy_loss = []
for log_prob, R in zip(self.saved_log_probs, returns):
policy_loss.append(-log_prob * R)
# Update policy
self.optimizer.zero_grad()
policy_loss = torch.stack(policy_loss).sum()
policy_loss.backward()
self.optimizer.step()
# Clear saved values
self.saved_log_probs = []
self.rewards = []
return policy_loss.item()
def train(self, env, num_episodes=1000):
"""Train the agent"""
episode_rewards = []
for episode in range(num_episodes):
state = env.reset()
total_reward = 0
done = False
while not done:
action = self.select_action(state)
next_state, reward, done = env.step(action)
self.rewards.append(reward)
state = next_state
total_reward += reward
# Update policy after episode
loss = self.update()
episode_rewards.append(total_reward)
if (episode + 1) % 100 == 0:
avg_reward = np.mean(episode_rewards[-100:])
print(f"Episode {episode + 1}, Avg Reward: {avg_reward:.2f}")
return episode_rewards
Actor-Critic Methods
Combine value-based and policy-based methods.
Advantage Actor-Critic (A2C)
class ActorCritic(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=64):
super(ActorCritic, self).__init__()
self.shared = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU()
)
# Actor (policy)
self.actor = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, action_dim),
nn.Softmax(dim=-1)
)
# Critic (value function)
self.critic = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
def forward(self, x):
shared_features = self.shared(x)
action_probs = self.actor(shared_features)
state_value = self.critic(shared_features)
return action_probs, state_value
class A2CAgent:
def __init__(self, state_dim, action_dim, lr=0.001, gamma=0.99):
self.gamma = gamma
self.model = ActorCritic(state_dim, action_dim)
self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
def select_action(self, state):
"""Select action and get state value"""
state = torch.FloatTensor(state).unsqueeze(0)
action_probs, state_value = self.model(state)
action_dist = torch.distributions.Categorical(action_probs)
action = action_dist.sample()
return action.item(), action_dist.log_prob(action), state_value
def train_step(self, log_prob, value, reward, next_value, done):
"""Single training step"""
        # Calculate advantage (TD target bootstraps from the critic's next-state value)
        if done:
            td_target = torch.tensor([[reward]], dtype=torch.float32)
        else:
            td_target = reward + self.gamma * next_value
        advantage = td_target - value
        # Actor loss (policy gradient; advantage is detached so it acts as a weight)
        actor_loss = -log_prob * advantage.detach()
        # Critic loss (target is detached so gradients flow only through value)
        critic_loss = F.mse_loss(value, td_target.detach())
# Total loss
loss = actor_loss + critic_loss
# Update
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return loss.item()
Multi-Armed Bandits
Simplified RL problem with one state.
Epsilon-Greedy Bandit
class EpsilonGreedyBandit:
def __init__(self, n_arms, epsilon=0.1):
self.n_arms = n_arms
self.epsilon = epsilon
self.q_values = np.zeros(n_arms) # Estimated action values
self.action_counts = np.zeros(n_arms) # Number of times each action selected
def select_action(self):
"""Epsilon-greedy action selection"""
if np.random.random() < self.epsilon:
return np.random.randint(self.n_arms)
else:
return np.argmax(self.q_values)
def update(self, action, reward):
"""Update Q-value estimate"""
self.action_counts[action] += 1
alpha = 1 / self.action_counts[action]
self.q_values[action] += alpha * (reward - self.q_values[action])
# Test bandit
true_rewards = [0.1, 0.5, 0.3, 0.7, 0.2]
bandit = EpsilonGreedyBandit(n_arms=5, epsilon=0.1)
total_reward = 0
for t in range(1000):
action = bandit.select_action()
reward = true_rewards[action] + np.random.normal(0, 0.1)
bandit.update(action, reward)
total_reward += reward
print(f"True rewards: {true_rewards}")
print(f"Estimated rewards: {bandit.q_values}")
print(f"Total reward: {total_reward:.2f}")
Upper Confidence Bound (UCB)
class UCBBandit:
def __init__(self, n_arms, c=2):
self.n_arms = n_arms
self.c = c
self.q_values = np.zeros(n_arms)
self.action_counts = np.zeros(n_arms)
self.t = 0
def select_action(self):
"""UCB action selection"""
self.t += 1
# Select each arm at least once
if 0 in self.action_counts:
return np.argmin(self.action_counts)
# UCB formula
ucb_values = self.q_values + self.c * np.sqrt(np.log(self.t) / self.action_counts)
return np.argmax(ucb_values)
def update(self, action, reward):
"""Update Q-value estimate"""
self.action_counts[action] += 1
alpha = 1 / self.action_counts[action]
self.q_values[action] += alpha * (reward - self.q_values[action])
Practical Tips
- Start Simple: Begin with simple environments and algorithms
- Hyperparameter Tuning: Learning rate, discount factor, and exploration rate are crucial
- Experience Replay: Store and replay past experiences (covered in deep RL)
- Reward Shaping: Design rewards carefully to guide learning
- Exploration vs Exploitation: Balance is key for good performance
- Curriculum Learning: Start with easy tasks and gradually increase difficulty
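The exploration-vs-exploitation tip can be made concrete by comparing two common epsilon-decay schedules side by side; the hyperparameters below (decay factor, decay horizon) are illustrative assumptions, not tuned values:

```python
import numpy as np

num_episodes = 1000
eps_start, eps_min = 1.0, 0.01

# Exponential decay: eps <- max(eps_min, eps * decay) each episode
decay = 0.995
exp_schedule = np.maximum(eps_min, eps_start * decay ** np.arange(num_episodes))

# Linear decay: reach eps_min after a fixed fraction (here 80%) of training
linear_schedule = np.maximum(
    eps_min,
    eps_start - (eps_start - eps_min) * np.arange(num_episodes) / (0.8 * num_episodes),
)

print(f"exponential at episode 500: {exp_schedule[500]:.3f}")
print(f"linear at episode 500:      {linear_schedule[500]:.3f}")
```

Exponential decay front-loads exploitation (epsilon drops fast early on), while linear decay keeps the agent exploring longer; which is better depends on how quickly the environment's rewards can be discovered.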
Resources
- "Reinforcement Learning: An Introduction" by Sutton and Barto
- OpenAI Gym: https://gym.openai.com/
- Stable Baselines3: https://stable-baselines3.readthedocs.io/
- David Silver's RL Course: https://www.davidsilver.uk/teaching/
Deep Reinforcement Learning
Deep RL combines deep learning with reinforcement learning to handle high-dimensional state and action spaces.
Table of Contents
- Deep Q-Networks (DQN)
- Policy Gradient Methods
- Actor-Critic Methods
- A3C (Asynchronous Advantage Actor-Critic)
- PPO (Proximal Policy Optimization)
- DDPG (Deep Deterministic Policy Gradient)
- SAC (Soft Actor-Critic)
- TD3 (Twin Delayed DDPG)
- Model-Based RL
- Multi-Agent RL
Deep Q-Networks (DQN)
DQN uses deep neural networks to approximate Q-values for high-dimensional states.
Basic DQN
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from collections import deque, namedtuple
import random
# Q-Network
class QNetwork(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dims=[128, 128]):
super(QNetwork, self).__init__()
layers = []
prev_dim = state_dim
for hidden_dim in hidden_dims:
layers.append(nn.Linear(prev_dim, hidden_dim))
layers.append(nn.ReLU())
prev_dim = hidden_dim
layers.append(nn.Linear(prev_dim, action_dim))
self.network = nn.Sequential(*layers)
def forward(self, state):
return self.network(state)
# Experience Replay Buffer
class ReplayBuffer:
def __init__(self, capacity=100000):
self.buffer = deque(maxlen=capacity)
self.experience = namedtuple('Experience',
['state', 'action', 'reward', 'next_state', 'done'])
def push(self, state, action, reward, next_state, done):
self.buffer.append(self.experience(state, action, reward, next_state, done))
def sample(self, batch_size):
experiences = random.sample(self.buffer, batch_size)
        # Convert via np.array first: building tensors from a list of arrays is very slow
        states = torch.FloatTensor(np.array([e.state for e in experiences]))
        actions = torch.LongTensor([e.action for e in experiences])
        rewards = torch.FloatTensor([e.reward for e in experiences])
        next_states = torch.FloatTensor(np.array([e.next_state for e in experiences]))
        dones = torch.FloatTensor([e.done for e in experiences])
return states, actions, rewards, next_states, dones
def __len__(self):
return len(self.buffer)
# DQN Agent
class DQNAgent:
def __init__(self, state_dim, action_dim, lr=0.001, gamma=0.99,
epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.01,
buffer_size=100000, batch_size=64, target_update_freq=10):
self.state_dim = state_dim
self.action_dim = action_dim
self.gamma = gamma
self.epsilon = epsilon
self.epsilon_decay = epsilon_decay
self.epsilon_min = epsilon_min
self.batch_size = batch_size
self.target_update_freq = target_update_freq
# Q-Networks
self.q_network = QNetwork(state_dim, action_dim)
self.target_network = QNetwork(state_dim, action_dim)
self.target_network.load_state_dict(self.q_network.state_dict())
self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)
self.replay_buffer = ReplayBuffer(buffer_size)
self.steps = 0
def select_action(self, state, training=True):
"""Epsilon-greedy action selection"""
if training and random.random() < self.epsilon:
return random.randint(0, self.action_dim - 1)
with torch.no_grad():
state = torch.FloatTensor(state).unsqueeze(0)
q_values = self.q_network(state)
return q_values.argmax(1).item()
def train_step(self):
"""Single training step"""
if len(self.replay_buffer) < self.batch_size:
return None
# Sample batch
states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)
# Current Q values
current_q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze()
# Target Q values
with torch.no_grad():
next_q_values = self.target_network(next_states).max(1)[0]
target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
# Compute loss
loss = F.mse_loss(current_q_values, target_q_values)
# Optimize
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 1.0)
self.optimizer.step()
# Update target network
self.steps += 1
if self.steps % self.target_update_freq == 0:
self.target_network.load_state_dict(self.q_network.state_dict())
# Decay epsilon
self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
return loss.item()
def train(self, env, num_episodes=1000, max_steps=1000):
"""Train the agent"""
episode_rewards = []
for episode in range(num_episodes):
state = env.reset()
total_reward = 0
for step in range(max_steps):
# Select and perform action
action = self.select_action(state, training=True)
next_state, reward, done, _ = env.step(action)
# Store transition
self.replay_buffer.push(state, action, reward, next_state, done)
# Train
loss = self.train_step()
state = next_state
total_reward += reward
if done:
break
episode_rewards.append(total_reward)
if (episode + 1) % 10 == 0:
avg_reward = np.mean(episode_rewards[-10:])
print(f"Episode {episode+1}, Avg Reward: {avg_reward:.2f}, "
f"Epsilon: {self.epsilon:.3f}")
return episode_rewards
Double DQN
Reduces the overestimation bias of standard Q-learning by using the online network to select the next action and the target network to evaluate it.
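In update form, action selection and action evaluation are decoupled:
y = r + gamma * Q_target(s', argmax_a' Q_online(s', a'))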
class DoubleDQNAgent(DQNAgent):
def train_step(self):
"""Double DQN training step"""
if len(self.replay_buffer) < self.batch_size:
return None
states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)
# Current Q values
current_q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze()
# Double DQN: use online network to select actions, target network to evaluate
with torch.no_grad():
# Select actions using online network
next_actions = self.q_network(next_states).argmax(1)
# Evaluate using target network
next_q_values = self.target_network(next_states).gather(1, next_actions.unsqueeze(1)).squeeze()
target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
# Compute loss
loss = F.mse_loss(current_q_values, target_q_values)
# Optimize
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 1.0)
self.optimizer.step()
# Update target network
self.steps += 1
if self.steps % self.target_update_freq == 0:
self.target_network.load_state_dict(self.q_network.state_dict())
self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
return loss.item()
Dueling DQN
Splits the network into separate state-value and action-advantage streams, which are recombined into Q-values.
class DuelingQNetwork(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=128):
super(DuelingQNetwork, self).__init__()
# Feature extraction
self.feature = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU()
)
# Value stream
self.value_stream = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
# Advantage stream
self.advantage_stream = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, action_dim)
)
def forward(self, state):
features = self.feature(state)
value = self.value_stream(features)
advantages = self.advantage_stream(features)
# Q(s,a) = V(s) + (A(s,a) - mean(A(s,a)))
q_values = value + (advantages - advantages.mean(dim=1, keepdim=True))
return q_values
Prioritized Experience Replay
Samples transitions with probability proportional to their TD error, and corrects the resulting bias with importance-sampling weights.
class PrioritizedReplayBuffer:
def __init__(self, capacity=100000, alpha=0.6):
self.capacity = capacity
self.alpha = alpha
self.buffer = []
self.priorities = np.zeros(capacity, dtype=np.float32)
self.position = 0
def push(self, state, action, reward, next_state, done):
max_priority = self.priorities.max() if self.buffer else 1.0
if len(self.buffer) < self.capacity:
self.buffer.append((state, action, reward, next_state, done))
else:
self.buffer[self.position] = (state, action, reward, next_state, done)
self.priorities[self.position] = max_priority
self.position = (self.position + 1) % self.capacity
def sample(self, batch_size, beta=0.4):
if len(self.buffer) == self.capacity:
priorities = self.priorities
else:
priorities = self.priorities[:self.position]
# Calculate sampling probabilities
probabilities = priorities ** self.alpha
probabilities /= probabilities.sum()
# Sample indices
indices = np.random.choice(len(self.buffer), batch_size, p=probabilities)
# Importance sampling weights
total = len(self.buffer)
weights = (total * probabilities[indices]) ** (-beta)
weights /= weights.max()
# Get experiences
experiences = [self.buffer[idx] for idx in indices]
states = torch.FloatTensor([e[0] for e in experiences])
actions = torch.LongTensor([e[1] for e in experiences])
rewards = torch.FloatTensor([e[2] for e in experiences])
next_states = torch.FloatTensor([e[3] for e in experiences])
dones = torch.FloatTensor([e[4] for e in experiences])
weights = torch.FloatTensor(weights)
return states, actions, rewards, next_states, dones, indices, weights
def update_priorities(self, indices, priorities):
for idx, priority in zip(indices, priorities):
self.priorities[idx] = priority
def __len__(self):
return len(self.buffer)
Policy Gradient Methods
REINFORCE with Baseline
class PolicyNetwork(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=128):
super(PolicyNetwork, self).__init__()
self.network = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, action_dim),
nn.Softmax(dim=-1)
)
def forward(self, state):
return self.network(state)
class ValueNetwork(nn.Module):
def __init__(self, state_dim, hidden_dim=128):
super(ValueNetwork, self).__init__()
self.network = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
def forward(self, state):
return self.network(state)
class REINFORCEWithBaseline:
def __init__(self, state_dim, action_dim, lr=0.001, gamma=0.99):
self.gamma = gamma
self.policy_net = PolicyNetwork(state_dim, action_dim)
self.value_net = ValueNetwork(state_dim)
self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=lr)
self.saved_log_probs = []
self.saved_values = []
self.rewards = []
def select_action(self, state):
state = torch.FloatTensor(state).unsqueeze(0)
# Get action probabilities and value
probs = self.policy_net(state)
value = self.value_net(state)
# Sample action
action_dist = torch.distributions.Categorical(probs)
action = action_dist.sample()
# Save log prob and value
self.saved_log_probs.append(action_dist.log_prob(action))
self.saved_values.append(value)
return action.item()
def train_step(self):
R = 0
returns = []
# Calculate returns
for r in reversed(self.rewards):
R = r + self.gamma * R
returns.insert(0, R)
returns = torch.tensor(returns)
returns = (returns - returns.mean()) / (returns.std() + 1e-8)
policy_losses = []
value_losses = []
# Calculate losses
for log_prob, value, R in zip(self.saved_log_probs, self.saved_values, returns):
advantage = R - value.item()
# Policy loss (REINFORCE with baseline)
policy_losses.append(-log_prob * advantage)
# Value loss
value_losses.append(F.mse_loss(value.squeeze(), R))
# Update policy network
self.policy_optimizer.zero_grad()
policy_loss = torch.stack(policy_losses).sum()
policy_loss.backward()
self.policy_optimizer.step()
# Update value network
self.value_optimizer.zero_grad()
value_loss = torch.stack(value_losses).sum()
value_loss.backward()
self.value_optimizer.step()
# Clear saved values
self.saved_log_probs = []
self.saved_values = []
self.rewards = []
return policy_loss.item(), value_loss.item()
Actor-Critic Methods
Advantage Actor-Critic (A2C)
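The critic's value estimate serves as a baseline, and the one-step advantage drives the actor update:
A(s, a) = r + gamma * V(s') - V(s)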
class ActorCritic(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=128):
super(ActorCritic, self).__init__()
# Shared layers
self.shared = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU()
)
# Actor head
self.actor = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, action_dim),
nn.Softmax(dim=-1)
)
# Critic head
self.critic = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
def forward(self, state):
shared_features = self.shared(state)
action_probs = self.actor(shared_features)
state_value = self.critic(shared_features)
return action_probs, state_value
class A2CAgent:
def __init__(self, state_dim, action_dim, lr=0.001, gamma=0.99,
value_coef=0.5, entropy_coef=0.01):
self.gamma = gamma
self.value_coef = value_coef
self.entropy_coef = entropy_coef
self.model = ActorCritic(state_dim, action_dim)
self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
def select_action(self, state):
state = torch.FloatTensor(state).unsqueeze(0)
action_probs, state_value = self.model(state)
action_dist = torch.distributions.Categorical(action_probs)
action = action_dist.sample()
return action.item(), action_dist.log_prob(action), action_dist.entropy(), state_value
def train_step(self, states, actions, rewards, next_states, dones):
"""Train on a batch of experiences"""
states = torch.FloatTensor(states)
actions = torch.LongTensor(actions)
rewards = torch.FloatTensor(rewards)
next_states = torch.FloatTensor(next_states)
dones = torch.FloatTensor(dones)
# Get action probabilities and values
action_probs, values = self.model(states)
_, next_values = self.model(next_states)
# Calculate advantages
td_targets = rewards + (1 - dones) * self.gamma * next_values.squeeze()
advantages = td_targets - values.squeeze()
# Actor loss
action_dist = torch.distributions.Categorical(action_probs)
log_probs = action_dist.log_prob(actions)
actor_loss = -(log_probs * advantages.detach()).mean()
# Critic loss
critic_loss = F.mse_loss(values.squeeze(), td_targets.detach())
# Entropy bonus (encourages exploration)
entropy = action_dist.entropy().mean()
# Total loss
loss = actor_loss + self.value_coef * critic_loss - self.entropy_coef * entropy
# Optimize
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
self.optimizer.step()
return loss.item(), actor_loss.item(), critic_loss.item(), entropy.item()
A3C
Asynchronous Advantage Actor-Critic runs multiple parallel workers, each with its own environment copy, that asynchronously update a shared global network.
import torch.multiprocessing as mp
class A3CAgent:
def __init__(self, state_dim, action_dim, lr=0.001, gamma=0.99):
self.state_dim = state_dim
self.action_dim = action_dim
self.gamma = gamma
# Global network (shared across workers)
self.global_model = ActorCritic(state_dim, action_dim)
self.global_model.share_memory()
self.optimizer = optim.Adam(self.global_model.parameters(), lr=lr)
def worker(self, worker_id, env_fn, num_episodes=1000):
"""Worker process for A3C"""
local_model = ActorCritic(self.state_dim, self.action_dim)
env = env_fn()
for episode in range(num_episodes):
state = env.reset()
done = False
states, actions, rewards = [], [], []
while not done:
# Sync local model with global model
local_model.load_state_dict(self.global_model.state_dict())
# Select action
state_tensor = torch.FloatTensor(state).unsqueeze(0)
action_probs, _ = local_model(state_tensor)
action_dist = torch.distributions.Categorical(action_probs)
action = action_dist.sample()
# Take action
next_state, reward, done, _ = env.step(action.item())
# Store transition
states.append(state)
actions.append(action.item())
rewards.append(reward)
state = next_state
# Update global network periodically
if len(states) >= 20 or done:
self._update_global(local_model, states, actions, rewards, next_state, done)
states, actions, rewards = [], [], []
def _update_global(self, local_model, states, actions, rewards, next_state, done):
"""Update global network using local gradients"""
states_tensor = torch.FloatTensor(states)
actions_tensor = torch.LongTensor(actions)
# Calculate returns
R = 0
if not done:
_, next_value = local_model(torch.FloatTensor(next_state).unsqueeze(0))
R = next_value.item()
returns = []
for r in reversed(rewards):
R = r + self.gamma * R
returns.insert(0, R)
returns = torch.FloatTensor(returns)
# Calculate loss
action_probs, values = local_model(states_tensor)
action_dist = torch.distributions.Categorical(action_probs)
log_probs = action_dist.log_prob(actions_tensor)
advantages = returns - values.squeeze()
actor_loss = -(log_probs * advantages.detach()).mean()
critic_loss = advantages.pow(2).mean()
entropy = action_dist.entropy().mean()
loss = actor_loss + 0.5 * critic_loss - 0.01 * entropy
# Update global network
self.optimizer.zero_grad()
loss.backward()
# Transfer gradients to global network
for local_param, global_param in zip(local_model.parameters(),
self.global_model.parameters()):
if global_param.grad is None:
global_param.grad = local_param.grad
else:
global_param.grad += local_param.grad
self.optimizer.step()
def train(self, env_fn, num_workers=4, num_episodes=1000):
"""Train using multiple parallel workers"""
processes = []
for worker_id in range(num_workers):
p = mp.Process(target=self.worker, args=(worker_id, env_fn, num_episodes))
p.start()
processes.append(p)
for p in processes:
p.join()
PPO
Proximal Policy Optimization is a policy gradient method with a clipped surrogate objective that limits how far each update can move the policy.
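The clipped objective, written with the probability ratio r_t between the new and old policies:
L_CLIP = E_t[min(r_t * A_t, clip(r_t, 1 - epsilon, 1 + epsilon) * A_t)], where r_t = pi_new(a_t|s_t) / pi_old(a_t|s_t)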
class PPOAgent:
def __init__(self, state_dim, action_dim, lr=0.0003, gamma=0.99,
epsilon=0.2, value_coef=0.5, entropy_coef=0.01,
epochs=10, batch_size=64):
self.gamma = gamma
self.epsilon = epsilon
self.value_coef = value_coef
self.entropy_coef = entropy_coef
self.epochs = epochs
self.batch_size = batch_size
self.model = ActorCritic(state_dim, action_dim)
self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
# Storage for rollouts
self.states = []
self.actions = []
self.rewards = []
self.values = []
self.log_probs = []
self.dones = []
def select_action(self, state):
"""Select action using current policy"""
state_tensor = torch.FloatTensor(state).unsqueeze(0)
with torch.no_grad():
action_probs, value = self.model(state_tensor)
action_dist = torch.distributions.Categorical(action_probs)
action = action_dist.sample()
log_prob = action_dist.log_prob(action)
return action.item(), log_prob.item(), value.item()
def store_transition(self, state, action, reward, log_prob, value, done):
"""Store transition in buffer"""
self.states.append(state)
self.actions.append(action)
self.rewards.append(reward)
self.log_probs.append(log_prob)
self.values.append(value)
self.dones.append(done)
def compute_gae(self, next_value, gamma=0.99, lam=0.95):
"""Compute Generalized Advantage Estimation"""
advantages = []
gae = 0
values = self.values + [next_value]
for t in reversed(range(len(self.rewards))):
delta = self.rewards[t] + gamma * values[t + 1] * (1 - self.dones[t]) - values[t]
gae = delta + gamma * lam * (1 - self.dones[t]) * gae
advantages.insert(0, gae)
returns = [adv + val for adv, val in zip(advantages, self.values)]
return advantages, returns
def update(self, next_state):
"""PPO update"""
# Get next value for GAE
with torch.no_grad():
next_state_tensor = torch.FloatTensor(next_state).unsqueeze(0)
_, next_value = self.model(next_state_tensor)
next_value = next_value.item()
# Compute advantages and returns
advantages, returns = self.compute_gae(next_value)
# Convert to tensors
states = torch.FloatTensor(self.states)
actions = torch.LongTensor(self.actions)
old_log_probs = torch.FloatTensor(self.log_probs)
advantages = torch.FloatTensor(advantages)
returns = torch.FloatTensor(returns)
# Normalize advantages
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# PPO update for multiple epochs
for _ in range(self.epochs):
# Get current policy
action_probs, values = self.model(states)
action_dist = torch.distributions.Categorical(action_probs)
log_probs = action_dist.log_prob(actions)
entropy = action_dist.entropy()
# Ratio for clipping
ratio = torch.exp(log_probs - old_log_probs)
# Clipped surrogate objective
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * advantages
actor_loss = -torch.min(surr1, surr2).mean()
# Value loss
critic_loss = F.mse_loss(values.squeeze(), returns)
# Entropy bonus
entropy_loss = -entropy.mean()
# Total loss
loss = actor_loss + self.value_coef * critic_loss + self.entropy_coef * entropy_loss
# Update
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
self.optimizer.step()
# Clear buffers
self.states = []
self.actions = []
self.rewards = []
self.values = []
self.log_probs = []
self.dones = []
return loss.item()
DDPG
Deep Deterministic Policy Gradient for continuous action spaces.
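The actor is trained by following the critic's gradient with respect to the action, which the code below implements as actor_loss = -critic(s, actor(s)).mean():
∇_θ J ≈ E_s[∇_a Q(s, a)|_{a = μ_θ(s)} ∇_θ μ_θ(s)]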
class Actor(nn.Module):
def __init__(self, state_dim, action_dim, max_action, hidden_dim=256):
super(Actor, self).__init__()
self.network = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, action_dim),
nn.Tanh()
)
self.max_action = max_action
def forward(self, state):
return self.max_action * self.network(state)
class Critic(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=256):
super(Critic, self).__init__()
self.network = nn.Sequential(
nn.Linear(state_dim + action_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
def forward(self, state, action):
return self.network(torch.cat([state, action], dim=1))
class DDPGAgent:
def __init__(self, state_dim, action_dim, max_action, lr=0.001,
gamma=0.99, tau=0.005, noise_std=0.1):
self.gamma = gamma
self.tau = tau
self.noise_std = noise_std
self.max_action = max_action
# Actor networks
self.actor = Actor(state_dim, action_dim, max_action)
self.actor_target = Actor(state_dim, action_dim, max_action)
self.actor_target.load_state_dict(self.actor.state_dict())
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)
# Critic networks
self.critic = Critic(state_dim, action_dim)
self.critic_target = Critic(state_dim, action_dim)
self.critic_target.load_state_dict(self.critic.state_dict())
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr)
self.replay_buffer = ReplayBuffer()
def select_action(self, state, training=True):
"""Select action with optional exploration noise"""
state = torch.FloatTensor(state).unsqueeze(0)
with torch.no_grad():
action = self.actor(state).cpu().numpy()[0]
if training:
noise = np.random.normal(0, self.noise_std, size=action.shape)
action = np.clip(action + noise, -self.max_action, self.max_action)
return action
def train_step(self, batch_size=64):
"""Single DDPG training step"""
if len(self.replay_buffer) < batch_size:
return None, None
# Sample batch
states, actions, rewards, next_states, dones = self.replay_buffer.sample(batch_size)
# Update critic
with torch.no_grad():
next_actions = self.actor_target(next_states)
target_q = self.critic_target(next_states, next_actions)
target_q = rewards.unsqueeze(1) + (1 - dones.unsqueeze(1)) * self.gamma * target_q
current_q = self.critic(states, actions)
critic_loss = F.mse_loss(current_q, target_q)
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
# Update actor
actor_loss = -self.critic(states, self.actor(states)).mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
# Soft update target networks
self._soft_update(self.actor, self.actor_target)
self._soft_update(self.critic, self.critic_target)
return actor_loss.item(), critic_loss.item()
def _soft_update(self, source, target):
"""Soft update of target network"""
for source_param, target_param in zip(source.parameters(), target.parameters()):
target_param.data.copy_(
self.tau * source_param.data + (1 - self.tau) * target_param.data
)
SAC
Soft Actor-Critic maximizes expected return plus policy entropy, using a stochastic actor and twin Q-critics.
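Note that the SAC update below calls self.actor.sample(states), which the deterministic Actor defined for DDPG does not provide. A minimal squashed-Gaussian policy that supplies it; this sketch (class name and layer sizes included) is illustrative, not part of the original notes:

```python
import torch
import torch.nn as nn

class GaussianActor(nn.Module):
    """Stochastic policy for SAC: sample() returns a tanh-squashed
    Gaussian action and its log-probability (illustrative sketch)."""
    def __init__(self, state_dim, action_dim, max_action, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        self.mu_head = nn.Linear(hidden_dim, action_dim)
        self.log_std_head = nn.Linear(hidden_dim, action_dim)
        self.max_action = max_action

    def sample(self, state):
        h = self.net(state)
        mu = self.mu_head(h)
        log_std = self.log_std_head(h).clamp(-20, 2)  # keep std in a sane range
        dist = torch.distributions.Normal(mu, log_std.exp())
        u = dist.rsample()              # reparameterized sample (gradient flows through)
        action = torch.tanh(u)          # squash into (-1, 1)
        # change-of-variables correction for the tanh squashing
        log_prob = dist.log_prob(u) - torch.log(1 - action.pow(2) + 1e-6)
        log_prob = log_prob.sum(dim=1, keepdim=True)
        return self.max_action * action, log_prob
```

With self.actor = GaussianActor(state_dim, action_dim, max_action), the calls to self.actor.sample in train_step line up: actions have shape (batch, action_dim) and log-probs (batch, 1), matching the (batch, 1) critic outputs.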
class SACAgent:
def __init__(self, state_dim, action_dim, max_action, lr=0.0003,
gamma=0.99, tau=0.005, alpha=0.2):
self.gamma = gamma
self.tau = tau
self.alpha = alpha # Temperature parameter
self.max_action = max_action
# Actor (stochastic policy)
self.actor = Actor(state_dim, action_dim, max_action)
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)
# Two Q-functions (critics)
self.critic_1 = Critic(state_dim, action_dim)
self.critic_2 = Critic(state_dim, action_dim)
self.critic_1_optimizer = optim.Adam(self.critic_1.parameters(), lr=lr)
self.critic_2_optimizer = optim.Adam(self.critic_2.parameters(), lr=lr)
# Target critics
self.critic_1_target = Critic(state_dim, action_dim)
self.critic_2_target = Critic(state_dim, action_dim)
self.critic_1_target.load_state_dict(self.critic_1.state_dict())
self.critic_2_target.load_state_dict(self.critic_2.state_dict())
self.replay_buffer = ReplayBuffer()
def train_step(self, batch_size=256):
"""SAC training step"""
if len(self.replay_buffer) < batch_size:
return None
states, actions, rewards, next_states, dones = self.replay_buffer.sample(batch_size)
# Update critics
with torch.no_grad():
next_actions, next_log_probs = self.actor.sample(next_states)
target_q1 = self.critic_1_target(next_states, next_actions)
target_q2 = self.critic_2_target(next_states, next_actions)
target_q = torch.min(target_q1, target_q2) - self.alpha * next_log_probs
target_q = rewards.unsqueeze(1) + (1 - dones.unsqueeze(1)) * self.gamma * target_q
current_q1 = self.critic_1(states, actions)
current_q2 = self.critic_2(states, actions)
critic_1_loss = F.mse_loss(current_q1, target_q)
critic_2_loss = F.mse_loss(current_q2, target_q)
self.critic_1_optimizer.zero_grad()
critic_1_loss.backward()
self.critic_1_optimizer.step()
self.critic_2_optimizer.zero_grad()
critic_2_loss.backward()
self.critic_2_optimizer.step()
# Update actor
new_actions, log_probs = self.actor.sample(states)
q1 = self.critic_1(states, new_actions)
q2 = self.critic_2(states, new_actions)
q = torch.min(q1, q2)
actor_loss = (self.alpha * log_probs - q).mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
# Soft update target networks
self._soft_update(self.critic_1, self.critic_1_target)
self._soft_update(self.critic_2, self.critic_2_target)
return actor_loss.item(), critic_1_loss.item(), critic_2_loss.item()
def _soft_update(self, source, target):
"""Polyak-average source parameters into the target network"""
for source_param, target_param in zip(source.parameters(), target.parameters()):
target_param.data.copy_(
self.tau * source_param.data + (1 - self.tau) * target_param.data
)
TD3
Twin Delayed Deep Deterministic Policy Gradient extends DDPG with clipped double-Q learning, target policy smoothing, and delayed policy updates.
class TD3Agent(DDPGAgent):
def __init__(self, state_dim, action_dim, max_action, lr=0.001,
gamma=0.99, tau=0.005, policy_noise=0.2,
noise_clip=0.5, policy_delay=2):
super().__init__(state_dim, action_dim, max_action, lr, gamma, tau)
# Twin critics
self.critic_2 = Critic(state_dim, action_dim)
self.critic_2_target = Critic(state_dim, action_dim)
self.critic_2_target.load_state_dict(self.critic_2.state_dict())
self.critic_2_optimizer = optim.Adam(self.critic_2.parameters(), lr=lr)
self.policy_noise = policy_noise
self.noise_clip = noise_clip
self.policy_delay = policy_delay
self.total_iterations = 0
def train_step(self, batch_size=64):
"""TD3 training step"""
self.total_iterations += 1
if len(self.replay_buffer) < batch_size:
return None, None
states, actions, rewards, next_states, dones = self.replay_buffer.sample(batch_size)
# Update critics
with torch.no_grad():
# Target policy smoothing
noise = torch.randn_like(actions) * self.policy_noise
noise = torch.clamp(noise, -self.noise_clip, self.noise_clip)
next_actions = self.actor_target(next_states)
next_actions = torch.clamp(next_actions + noise, -self.max_action, self.max_action)
# Twin Q targets
target_q1 = self.critic_target(next_states, next_actions)
target_q2 = self.critic_2_target(next_states, next_actions)
target_q = torch.min(target_q1, target_q2)
target_q = rewards.unsqueeze(1) + (1 - dones.unsqueeze(1)) * self.gamma * target_q
current_q1 = self.critic(states, actions)
current_q2 = self.critic_2(states, actions)
critic_1_loss = F.mse_loss(current_q1, target_q)
critic_2_loss = F.mse_loss(current_q2, target_q)
self.critic_optimizer.zero_grad()
critic_1_loss.backward()
self.critic_optimizer.step()
self.critic_2_optimizer.zero_grad()
critic_2_loss.backward()
self.critic_2_optimizer.step()
# Delayed policy updates
actor_loss = None
if self.total_iterations % self.policy_delay == 0:
# Update actor
actor_loss = -self.critic(states, self.actor(states)).mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
# Soft update target networks
self._soft_update(self.actor, self.actor_target)
self._soft_update(self.critic, self.critic_target)
self._soft_update(self.critic_2, self.critic_2_target)
return actor_loss.item() if actor_loss is not None else None, critic_1_loss.item()
Practical Tips
- Hyperparameter Tuning: the learning rate and discount factor have an outsized effect on stability; tune these first
- Reward Scaling: normalize or clip rewards for stable training
- Network Architecture: start simple and add capacity only as needed
- Exploration: balance exploration and exploitation (epsilon schedules, entropy bonuses)
- Curriculum Learning: start with easier tasks and gradually increase difficulty
- Distributed Training: run parallel environments for faster data collection
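The reward-scaling tip can be made concrete with a tiny running normalizer; a minimal sketch using Welford's online algorithm (the class name is illustrative):

```python
class RunningRewardNormalizer:
    """Track a running mean/variance of rewards online (Welford's algorithm)
    and use them to normalize rewards before they enter the replay buffer."""
    def __init__(self, eps=1e-8):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # sum of squared deviations from the running mean
        self.eps = eps

    def update(self, reward):
        self.count += 1
        delta = reward - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (reward - self.mean)

    def normalize(self, reward):
        std = (self.m2 / max(self.count, 1)) ** 0.5
        return (reward - self.mean) / (std + self.eps)
```

Calling update(r) on every environment reward and storing normalize(r) keeps reward magnitudes roughly unit-scale regardless of the environment.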
Resources
- OpenAI Spinning Up: https://spinningup.openai.com/
- Stable Baselines3: https://stable-baselines3.readthedocs.io/
- "Deep Reinforcement Learning Hands-On" by Maxim Lapan
- DeepMind papers: https://www.deepmind.com/research
Generative Models
Generative models learn to create new data samples that resemble the training data distribution.
Table of Contents
- Introduction
- Generative Adversarial Networks (GANs)
- Variational Autoencoders (VAEs)
- Normalizing Flows
- Autoregressive Models
- Energy-Based Models
- Diffusion Models
Introduction
Types of Generative Models:
- Explicit Density: Models that define an explicit probability distribution (VAE, Flow models)
- Implicit Density: Models that can sample without explicit density (GANs)
- Tractable: Can compute exact likelihoods (Autoregressive, Flow models)
- Approximate: Use approximate inference (VAEs)
Generative Adversarial Networks
GANs use two networks competing against each other: Generator and Discriminator.
Basic GAN
Objective Function:
min_G max_D V(D,G) = E_x[log D(x)] + E_z[log(1 - D(G(z)))]
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# Generator Network
class Generator(nn.Module):
def __init__(self, latent_dim=100, img_shape=(1, 28, 28)):
super(Generator, self).__init__()
self.img_shape = img_shape
def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat))
layers.append(nn.LeakyReLU(0.2))
return layers
self.model = nn.Sequential(
*block(latent_dim, 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
nn.Linear(1024, int(np.prod(img_shape))),
nn.Tanh()
)
def forward(self, z):
img = self.model(z)
img = img.view(img.size(0), *self.img_shape)
return img
# Discriminator Network
class Discriminator(nn.Module):
def __init__(self, img_shape=(1, 28, 28)):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(int(np.prod(img_shape)), 512),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, img):
img_flat = img.view(img.size(0), -1)
validity = self.model(img_flat)
return validity
# Training GAN
class GANTrainer:
def __init__(self, generator, discriminator, latent_dim=100,
lr=0.0002, betas=(0.5, 0.999)):
self.generator = generator
self.discriminator = discriminator
self.latent_dim = latent_dim
# Optimizers
self.optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=betas)
self.optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=betas)
# Loss function
self.adversarial_loss = nn.BCELoss()
def train_step(self, real_imgs):
batch_size = real_imgs.size(0)
# Adversarial ground truths
valid = torch.ones(batch_size, 1)
fake = torch.zeros(batch_size, 1)
# ---------------------
# Train Discriminator
# ---------------------
self.optimizer_D.zero_grad()
# Loss for real images
real_loss = self.adversarial_loss(self.discriminator(real_imgs), valid)
# Loss for fake images
z = torch.randn(batch_size, self.latent_dim)
fake_imgs = self.generator(z)
fake_loss = self.adversarial_loss(self.discriminator(fake_imgs.detach()), fake)
# Total discriminator loss
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
self.optimizer_D.step()
# -----------------
# Train Generator
# -----------------
self.optimizer_G.zero_grad()
# Generate fake images
z = torch.randn(batch_size, self.latent_dim)
gen_imgs = self.generator(z)
# Generator loss (fool discriminator)
g_loss = self.adversarial_loss(self.discriminator(gen_imgs), valid)
g_loss.backward()
self.optimizer_G.step()
return d_loss.item(), g_loss.item()
def train(self, dataloader, num_epochs=100):
"""Train GAN"""
for epoch in range(num_epochs):
for i, (imgs, _) in enumerate(dataloader):
d_loss, g_loss = self.train_step(imgs)
if i % 100 == 0:
print(f"[Epoch {epoch}/{num_epochs}] [Batch {i}] "
f"[D loss: {d_loss:.4f}] [G loss: {g_loss:.4f}]")
# Sample images
if epoch % 10 == 0:
self.sample_images(epoch)
def sample_images(self, epoch, n_row=10):
"""Generate and save sample images"""
z = torch.randn(n_row**2, self.latent_dim)
gen_imgs = self.generator(z)
import torchvision.utils as vutils
vutils.save_image(gen_imgs.data, f"images/epoch_{epoch}.png",
nrow=n_row, normalize=True)
# Example usage
img_shape = (1, 28, 28)
generator = Generator(latent_dim=100, img_shape=img_shape)
discriminator = Discriminator(img_shape=img_shape)
trainer = GANTrainer(generator, discriminator)
# trainer.train(dataloader, num_epochs=100)
Deep Convolutional GAN (DCGAN)
Replaces the fully connected generator and discriminator with convolutional architectures, which scale far better to image data.
class DCGANGenerator(nn.Module):
def __init__(self, latent_dim=100, channels=3):
super(DCGANGenerator, self).__init__()
self.init_size = 4
self.l1 = nn.Linear(latent_dim, 128 * self.init_size ** 2)
self.conv_blocks = nn.Sequential(
nn.BatchNorm2d(128),
nn.Upsample(scale_factor=2), # 4x4 -> 8x8
nn.Conv2d(128, 128, 3, stride=1, padding=1),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2),
nn.Upsample(scale_factor=2), # 8x8 -> 16x16
nn.Conv2d(128, 64, 3, stride=1, padding=1),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2),
nn.Upsample(scale_factor=2), # 16x16 -> 32x32
nn.Conv2d(64, channels, 3, stride=1, padding=1),
nn.Tanh()
)
def forward(self, z):
out = self.l1(z)
out = out.view(out.shape[0], 128, self.init_size, self.init_size)
img = self.conv_blocks(out)
return img
class DCGANDiscriminator(nn.Module):
def __init__(self, channels=3):
super(DCGANDiscriminator, self).__init__()
def discriminator_block(in_filters, out_filters, bn=True):
block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1),
nn.LeakyReLU(0.2),
nn.Dropout2d(0.25)]
if bn:
block.append(nn.BatchNorm2d(out_filters))
return block
self.model = nn.Sequential(
*discriminator_block(channels, 16, bn=False), # 32x32 -> 16x16
*discriminator_block(16, 32), # 16x16 -> 8x8
*discriminator_block(32, 64), # 8x8 -> 4x4
*discriminator_block(64, 128), # 4x4 -> 2x2
)
# Output layer
ds_size = 2
self.adv_layer = nn.Sequential(
nn.Linear(128 * ds_size ** 2, 1),
nn.Sigmoid()
)
def forward(self, img):
out = self.model(img)
out = out.view(out.shape[0], -1)
validity = self.adv_layer(out)
return validity
Conditional GAN (cGAN)
Conditions both the generator and the discriminator on class labels, so samples of a chosen class can be generated.
class ConditionalGenerator(nn.Module):
def __init__(self, latent_dim=100, n_classes=10, img_shape=(1, 28, 28)):
super(ConditionalGenerator, self).__init__()
self.img_shape = img_shape
self.label_emb = nn.Embedding(n_classes, n_classes)
def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat))
layers.append(nn.LeakyReLU(0.2))
return layers
self.model = nn.Sequential(
*block(latent_dim + n_classes, 128, normalize=False),
*block(128, 256),
*block(256, 512),
*block(512, 1024),
nn.Linear(1024, int(np.prod(img_shape))),
nn.Tanh()
)
def forward(self, noise, labels):
# Concatenate label embedding and noise
gen_input = torch.cat((self.label_emb(labels), noise), -1)
img = self.model(gen_input)
img = img.view(img.size(0), *self.img_shape)
return img
class ConditionalDiscriminator(nn.Module):
def __init__(self, n_classes=10, img_shape=(1, 28, 28)):
super(ConditionalDiscriminator, self).__init__()
self.label_embedding = nn.Embedding(n_classes, n_classes)
self.model = nn.Sequential(
nn.Linear(n_classes + int(np.prod(img_shape)), 512),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, img, labels):
# Concatenate label embedding and image
d_in = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels)), -1)
validity = self.model(d_in)
return validity
Wasserstein GAN (WGAN)
class WGANTrainer:
def __init__(self, generator, discriminator, latent_dim=100,
lr=0.00005, n_critic=5, clip_value=0.01):
self.generator = generator
self.discriminator = discriminator
self.latent_dim = latent_dim
self.n_critic = n_critic
self.clip_value = clip_value
self._critic_steps = 0  # critic updates since the last generator update
# RMSprop optimizers (the original WGAN avoids momentum-based optimizers)
self.optimizer_G = optim.RMSprop(generator.parameters(), lr=lr)
self.optimizer_D = optim.RMSprop(discriminator.parameters(), lr=lr)
def train_step(self, real_imgs):
batch_size = real_imgs.size(0)
# ---------------------
# Train Critic (Discriminator)
# ---------------------
self.optimizer_D.zero_grad()
# Sample noise
z = torch.randn(batch_size, self.latent_dim)
fake_imgs = self.generator(z).detach()
# Wasserstein loss: the critic maximizes D(real) - D(fake)
loss_D = -torch.mean(self.discriminator(real_imgs)) + \
torch.mean(self.discriminator(fake_imgs))
loss_D.backward()
self.optimizer_D.step()
# Clip weights to enforce the Lipschitz constraint
for p in self.discriminator.parameters():
p.data.clamp_(-self.clip_value, self.clip_value)
# Train the generator only once every n_critic critic updates
self._critic_steps += 1
if self._critic_steps < self.n_critic:
return loss_D.item(), None
self._critic_steps = 0
# -----------------
# Train Generator
# -----------------
self.optimizer_G.zero_grad()
z = torch.randn(batch_size, self.latent_dim)
gen_imgs = self.generator(z)
# Generator loss: minimize -D(fake)
loss_G = -torch.mean(self.discriminator(gen_imgs))
loss_G.backward()
self.optimizer_G.step()
return loss_D.item(), loss_G.item()
StyleGAN Concepts
class StyleGANGenerator(nn.Module):
"""Simplified StyleGAN architecture"""
def __init__(self, latent_dim=512, style_dim=512, n_mlp=8):
super(StyleGANGenerator, self).__init__()
# Mapping network (converts z to w)
layers = []
for i in range(n_mlp):
layers.append(nn.Linear(latent_dim if i == 0 else style_dim, style_dim))
layers.append(nn.LeakyReLU(0.2))
self.mapping = nn.Sequential(*layers)
# Synthesis network (generates image from w)
self.const_input = nn.Parameter(torch.randn(1, 512, 4, 4))
# Progressive layers with AdaIN
self.prog_blocks = nn.ModuleList()
self.style_blocks = nn.ModuleList()
channels = [512, 512, 512, 256, 128, 64, 32]
for i in range(len(channels) - 1):
self.prog_blocks.append(
nn.Sequential(
nn.Upsample(scale_factor=2),
nn.Conv2d(channels[i], channels[i+1], 3, padding=1),
nn.LeakyReLU(0.2)
)
)
self.style_blocks.append(
nn.Linear(style_dim, channels[i+1] * 2) # For AdaIN
)
self.to_rgb = nn.Conv2d(channels[-1], 3, 1)
def forward(self, z):
# Map to style space
w = self.mapping(z)
# Start with constant
x = self.const_input.repeat(z.size(0), 1, 1, 1)
# Apply progressive blocks with style modulation
for prog_block, style_block in zip(self.prog_blocks, self.style_blocks):
x = prog_block(x)
# AdaIN (Adaptive Instance Normalization)
style = style_block(w).unsqueeze(2).unsqueeze(3)
style_mean, style_std = style.chunk(2, 1)
x = F.instance_norm(x)
x = x * (style_std + 1) + style_mean
# Convert to RGB
img = self.to_rgb(x)
return torch.tanh(img)
Variational Autoencoders
VAEs learn a latent representation by maximizing a variational lower bound on the data likelihood.
Objective (ELBO):
log p(x) ≥ E_q[log p(x|z)] - KL(q(z|x) || p(z))
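For a Gaussian encoder q(z|x) = N(μ, σ²) and a standard-normal prior, the KL term has a closed form (this is exactly what the KLD line in the VAE code computes, with logvar = log σ²):

```latex
\mathrm{KL}\big(\mathcal{N}(\mu,\sigma^2)\,\|\,\mathcal{N}(0,I)\big)
= -\tfrac{1}{2}\sum_{j=1}^{d}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
```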
Basic VAE
class VAE(nn.Module):
def __init__(self, input_dim=784, latent_dim=20, hidden_dim=400):
super(VAE, self).__init__()
# Encoder
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.fc_mu = nn.Linear(hidden_dim, latent_dim)
self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
# Decoder
self.fc3 = nn.Linear(latent_dim, hidden_dim)
self.fc4 = nn.Linear(hidden_dim, input_dim)
def encode(self, x):
"""Encode input to latent distribution parameters"""
h = F.relu(self.fc1(x))
mu = self.fc_mu(h)
logvar = self.fc_logvar(h)
return mu, logvar
def reparameterize(self, mu, logvar):
"""Reparameterization trick"""
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, z):
"""Decode latent to output"""
h = F.relu(self.fc3(z))
return torch.sigmoid(self.fc4(h))
def forward(self, x):
mu, logvar = self.encode(x.view(-1, 784))
z = self.reparameterize(mu, logvar)
recon = self.decode(z)
return recon, mu, logvar
def vae_loss(recon_x, x, mu, logvar):
"""VAE loss function"""
# Reconstruction loss (binary cross-entropy)
BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
# KL divergence
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return BCE + KLD
# Training
model = VAE()
optimizer = optim.Adam(model.parameters(), lr=0.001)
def train_vae(model, dataloader, num_epochs=10):
model.train()
for epoch in range(num_epochs):
train_loss = 0
for batch_idx, (data, _) in enumerate(dataloader):
optimizer.zero_grad()
recon_batch, mu, logvar = model(data)
loss = vae_loss(recon_batch, data, mu, logvar)
loss.backward()
train_loss += loss.item()
optimizer.step()
if batch_idx % 100 == 0:
print(f'Epoch: {epoch} [{batch_idx * len(data)}/{len(dataloader.dataset)}] '
f'Loss: {loss.item() / len(data):.4f}')
print(f'Epoch {epoch} Average loss: {train_loss / len(dataloader.dataset):.4f}')
Convolutional VAE
class ConvVAE(nn.Module):
def __init__(self, latent_dim=128, channels=3):
super(ConvVAE, self).__init__()
# Encoder
self.encoder = nn.Sequential(
nn.Conv2d(channels, 32, 4, 2, 1), # 32x32 -> 16x16
nn.ReLU(),
nn.Conv2d(32, 64, 4, 2, 1), # 16x16 -> 8x8
nn.ReLU(),
nn.Conv2d(64, 128, 4, 2, 1), # 8x8 -> 4x4
nn.ReLU(),
nn.Conv2d(128, 256, 4, 2, 1), # 4x4 -> 2x2
nn.ReLU()
)
self.fc_mu = nn.Linear(256 * 2 * 2, latent_dim)
self.fc_logvar = nn.Linear(256 * 2 * 2, latent_dim)
# Decoder
self.fc_decode = nn.Linear(latent_dim, 256 * 2 * 2)
self.decoder = nn.Sequential(
nn.ConvTranspose2d(256, 128, 4, 2, 1), # 2x2 -> 4x4
nn.ReLU(),
nn.ConvTranspose2d(128, 64, 4, 2, 1), # 4x4 -> 8x8
nn.ReLU(),
nn.ConvTranspose2d(64, 32, 4, 2, 1), # 8x8 -> 16x16
nn.ReLU(),
nn.ConvTranspose2d(32, channels, 4, 2, 1), # 16x16 -> 32x32
nn.Sigmoid()
)
def encode(self, x):
h = self.encoder(x)
h = h.view(h.size(0), -1)
return self.fc_mu(h), self.fc_logvar(h)
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, z):
h = self.fc_decode(z)
h = h.view(h.size(0), 256, 2, 2)
return self.decoder(h)
def forward(self, x):
mu, logvar = self.encode(x)
z = self.reparameterize(mu, logvar)
return self.decode(z), mu, logvar
Beta-VAE
def beta_vae_loss(recon_x, x, mu, logvar, beta=4.0):
"""Beta-VAE loss with adjustable KL weight"""
# Reconstruction loss
BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
# KL divergence with beta weight
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return BCE + beta * KLD
Normalizing Flows
Normalizing flows model complex distributions as a sequence of invertible transformations of a simple base distribution; the change-of-variables formula keeps the likelihood tractable.
Simple Flow
class CouplingLayer(nn.Module):
"""Affine coupling layer"""
def __init__(self, dim, hidden_dim=256):
super(CouplingLayer, self).__init__()
self.dim = dim
self.split = dim // 2
# Scale and translate networks
self.scale_net = nn.Sequential(
nn.Linear(self.split, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, dim - self.split),
nn.Tanh()
)
self.translate_net = nn.Sequential(
nn.Linear(self.split, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, dim - self.split)
)
def forward(self, x, reverse=False):
x1, x2 = x[:, :self.split], x[:, self.split:]
if not reverse:
# Forward pass
s = self.scale_net(x1)
t = self.translate_net(x1)
y2 = x2 * torch.exp(s) + t
y = torch.cat([x1, y2], dim=1)
log_det = torch.sum(s, dim=1)
else:
# Inverse pass
s = self.scale_net(x1)
t = self.translate_net(x1)
y2 = (x2 - t) * torch.exp(-s)
y = torch.cat([x1, y2], dim=1)
log_det = -torch.sum(s, dim=1)
return y, log_det
class NormalizingFlow(nn.Module):
def __init__(self, dim, num_layers=8):
super(NormalizingFlow, self).__init__()
# Note: in practice the dimensions are permuted (or the split alternated)
# between coupling layers so that every dimension gets transformed;
# omitted here for brevity
self.layers = nn.ModuleList([
CouplingLayer(dim) for _ in range(num_layers)
])
def forward(self, x, reverse=False):
log_det_sum = 0
layers = reversed(self.layers) if reverse else self.layers
for layer in layers:
x, log_det = layer(x, reverse=reverse)
log_det_sum += log_det
return x, log_det_sum
def log_prob(self, x):
"""Compute log probability"""
z, log_det = self.forward(x, reverse=False)
# Base distribution (standard normal)
log_prob_z = -0.5 * (z ** 2 + np.log(2 * np.pi)).sum(dim=1)
return log_prob_z + log_det
Autoregressive Models
Autoregressive models generate data sequentially, one element at a time, conditioning each element on those generated before it.
PixelCNN
class MaskedConv2d(nn.Conv2d):
"""Masked convolution for autoregressive generation"""
def __init__(self, mask_type, *args, **kwargs):
super(MaskedConv2d, self).__init__(*args, **kwargs)
self.register_buffer('mask', torch.zeros_like(self.weight))
kh, kw = self.kernel_size
self.mask[:, :, :kh // 2] = 1 # all rows above the center
self.mask[:, :, kh // 2, :kw // 2 + 1] = 1 # center row, up to and including the center pixel
if mask_type == 'A':
# Mask type A (first layer only): also exclude the center pixel
self.mask[:, :, kh // 2, kw // 2] = 0
def forward(self, x):
self.weight.data *= self.mask
return super(MaskedConv2d, self).forward(x)
class PixelCNN(nn.Module):
def __init__(self, n_channels=1, n_filters=64, n_layers=7):
super(PixelCNN, self).__init__()
self.layers = nn.ModuleList()
# First layer (mask type A)
self.layers.append(
nn.Sequential(
MaskedConv2d('A', n_channels, n_filters, 7, padding=3),
nn.BatchNorm2d(n_filters),
nn.ReLU()
)
)
# Hidden layers (mask type B)
for _ in range(n_layers):
self.layers.append(
nn.Sequential(
MaskedConv2d('B', n_filters, n_filters, 7, padding=3),
nn.BatchNorm2d(n_filters),
nn.ReLU()
)
)
# Output layer
self.output = nn.Conv2d(n_filters, n_channels * 256, 1)
def forward(self, x):
for layer in self.layers:
x = layer(x)
x = self.output(x)
# Reshape for pixel-wise softmax
b, _, h, w = x.size()
x = x.view(b, 256, -1, h, w)
return x
Energy-Based Models
Energy-based models define an unnormalized density through an energy function: p(x) ∝ exp(-E(x))
class EnergyBasedModel(nn.Module):
def __init__(self, input_dim):
super(EnergyBasedModel, self).__init__()
self.energy_net = nn.Sequential(
nn.Linear(input_dim, 256),
nn.ReLU(),
nn.Linear(256, 256),
nn.ReLU(),
nn.Linear(256, 1)
)
def energy(self, x):
"""Compute energy E(x)"""
return self.energy_net(x)
def sample_langevin(self, x, n_steps=100, step_size=0.01):
"""Sample using Langevin dynamics"""
x = x.clone().detach().requires_grad_(True)
for _ in range(n_steps):
energy = self.energy(x).sum()
grad = torch.autograd.grad(energy, x)[0]
noise = torch.randn_like(x) * np.sqrt(step_size * 2)
x = x - step_size * grad + noise
return x.detach()
Diffusion Models
Diffusion models gradually corrupt data with noise over many steps, then learn to reverse the process by denoising.
DDPM (Denoising Diffusion Probabilistic Models)
class DiffusionModel(nn.Module):
def __init__(self, timesteps=1000):
super(DiffusionModel, self).__init__()
self.timesteps = timesteps
# Linear beta schedule
self.betas = torch.linspace(0.0001, 0.02, timesteps)
self.alphas = 1 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# Noise prediction network (U-Net)
self.noise_predictor = self._build_unet()
def _build_unet(self):
"""Simple U-Net for noise prediction"""
# Simplified version
return nn.Sequential(
nn.Conv2d(3, 64, 3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 64, 3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 3, 3, padding=1)
)
def q_sample(self, x0, t, noise=None):
"""Forward diffusion: add noise to x0"""
if noise is None:
noise = torch.randn_like(x0)
sqrt_alphas_cumprod_t = self.alphas_cumprod[t].sqrt()
sqrt_one_minus_alphas_cumprod_t = (1 - self.alphas_cumprod[t]).sqrt()
return sqrt_alphas_cumprod_t * x0 + sqrt_one_minus_alphas_cumprod_t * noise
def p_sample(self, xt, t):
"""Reverse diffusion: denoise xt by one step"""
# Predict noise
predicted_noise = self.noise_predictor(xt)
alpha_t = self.alphas[t]
alpha_cumprod_t = self.alphas_cumprod[t]
beta_t = self.betas[t]
# Posterior mean: (1/sqrt(alpha_t)) * (x_t - beta_t/sqrt(1 - alpha_bar_t) * eps)
mean = (xt - (beta_t / (1 - alpha_cumprod_t).sqrt()) * predicted_noise) / alpha_t.sqrt()
if t > 0:
noise = torch.randn_like(xt)
x_prev = mean + beta_t.sqrt() * noise
else:
x_prev = mean
return x_prev
def sample(self, shape):
"""Generate samples"""
device = next(self.parameters()).device
# Start from random noise
x = torch.randn(shape).to(device)
# Iteratively denoise
for t in reversed(range(self.timesteps)):
x = self.p_sample(x, t)
return x
Evaluation Metrics
# Inception Score (IS)
def inception_score(imgs, splits=10):
"""Higher is better"""
from torchvision.models import inception_v3
inception_model = inception_v3(pretrained=True, transform_input=False)
inception_model.eval()
# Get predictions
with torch.no_grad():
preds = inception_model(imgs)
preds = F.softmax(preds, dim=1)
# Compute IS
split_scores = []
for k in range(splits):
part = preds[k * (len(preds) // splits): (k + 1) * (len(preds) // splits)]
py = part.mean(dim=0)
scores = []
for i in range(part.shape[0]):
pyx = part[i]
scores.append((pyx * (torch.log(pyx) - torch.log(py))).sum())
split_scores.append(torch.exp(torch.mean(torch.stack(scores))))
return torch.mean(torch.stack(split_scores)), torch.std(torch.stack(split_scores))
# Fréchet Inception Distance (FID)
def calculate_fid(real_features, fake_features):
"""Lower is better. Takes Inception pool features (numpy arrays of shape (N, 2048))"""
from scipy.linalg import sqrtm
mu1, mu2 = real_features.mean(axis=0), fake_features.mean(axis=0)
sigma1 = np.cov(real_features, rowvar=False)
sigma2 = np.cov(fake_features, rowvar=False)
covmean = sqrtm(sigma1 @ sigma2)
if np.iscomplexobj(covmean):
covmean = covmean.real # numerical noise can introduce tiny imaginary parts
return np.sum((mu1 - mu2) ** 2) + np.trace(sigma1 + sigma2 - 2 * covmean)
Practical Tips
- GAN Training: Balance G and D, use label smoothing, add noise to inputs
- VAE: Choose appropriate beta value, use warm-up for KL term
- Stability: Monitor losses, use spectral normalization
- Architecture: Start simple, gradually add complexity
- Evaluation: Use multiple metrics (IS, FID, visual inspection)
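Two of the GAN tips above (one-sided label smoothing and instance noise) can be sketched as a single discriminator step. This is a minimal sketch assuming a BCE-based discriminator; the helper name `d_train_step` and the default values are illustrative, not from the original notes.

```python
import torch
import torch.nn as nn

def d_train_step(discriminator, real_imgs, fake_imgs, opt_D,
                 smooth=0.9, noise_std=0.05):
    """One discriminator update with label smoothing and instance noise."""
    criterion = nn.BCELoss()
    opt_D.zero_grad()
    # Instance noise: perturb both real and fake inputs slightly,
    # which softens the discriminator's decision boundary
    real_in = real_imgs + noise_std * torch.randn_like(real_imgs)
    fake_in = fake_imgs.detach() + noise_std * torch.randn_like(fake_imgs)
    # One-sided label smoothing: real targets are 0.9 instead of 1.0
    real_loss = criterion(discriminator(real_in),
                          torch.full((real_imgs.size(0), 1), smooth))
    fake_loss = criterion(discriminator(fake_in),
                          torch.zeros(fake_imgs.size(0), 1))
    loss = real_loss + fake_loss
    loss.backward()
    opt_D.step()
    return loss.item()
```

The generator step is unchanged; only the discriminator sees smoothed labels and noisy inputs.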
Resources
- "Generative Deep Learning" by David Foster
- OpenAI papers: https://openai.com/research/
- Distill.pub: https://distill.pub/
- Papers with Code: https://paperswithcode.com/
Deep Generative Models
Advanced architectures and techniques for generating high-quality data.
Table of Contents
- Transformer-based Generative Models
- Diffusion Models
- Vector Quantized Models
- NeRF and 3D Generation
- Multimodal Generative Models
Transformer-based Generative Models
GPT (Generative Pre-trained Transformer)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads, dropout=0.1):
super(MultiHeadAttention, self).__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, mask=None):
batch_size = query.size(0)
# Linear projections
Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# Scaled dot-product attention
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
context = torch.matmul(attn_weights, V)
# Concatenate heads
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
# Final linear projection
output = self.W_o(context)
return output, attn_weights
class FeedForward(nn.Module):
def __init__(self, d_model, d_ff, dropout=0.1):
super(FeedForward, self).__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.linear2(self.dropout(F.gelu(self.linear1(x))))
class TransformerBlock(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super(TransformerBlock, self).__init__()
self.attention = MultiHeadAttention(d_model, num_heads, dropout)
self.feed_forward = FeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(self, x, mask=None):
# Multi-head attention with residual connection
attn_output, _ = self.attention(x, x, x, mask)
x = self.norm1(x + self.dropout1(attn_output))
# Feed-forward with residual connection
ff_output = self.feed_forward(x)
x = self.norm2(x + self.dropout2(ff_output))
return x
class GPTModel(nn.Module):
def __init__(self, vocab_size, d_model=768, num_heads=12, num_layers=12,
d_ff=3072, max_seq_length=1024, dropout=0.1):
super(GPTModel, self).__init__()
self.d_model = d_model
self.max_seq_length = max_seq_length
# Token and position embeddings
self.token_embedding = nn.Embedding(vocab_size, d_model)
self.position_embedding = nn.Embedding(max_seq_length, d_model)
# Transformer blocks
self.blocks = nn.ModuleList([
TransformerBlock(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
# Output projection
self.ln_f = nn.LayerNorm(d_model)
self.head = nn.Linear(d_model, vocab_size, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, input_ids, targets=None):
batch_size, seq_length = input_ids.size()
# Create position ids
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
# Embeddings
token_embeds = self.token_embedding(input_ids)
position_embeds = self.position_embedding(position_ids)
x = self.dropout(token_embeds + position_embeds)
# Causal mask
mask = torch.tril(torch.ones(seq_length, seq_length, device=input_ids.device))
mask = mask.view(1, 1, seq_length, seq_length)
# Transformer blocks
for block in self.blocks:
x = block(x, mask)
# Output
x = self.ln_f(x)
logits = self.head(x)
loss = None
if targets is not None:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
return logits, loss
def generate(self, input_ids, max_new_tokens=50, temperature=1.0, top_k=None):
"""Generate text autoregressively"""
for _ in range(max_new_tokens):
# Crop context if needed
idx_cond = input_ids if input_ids.size(1) <= self.max_seq_length else input_ids[:, -self.max_seq_length:]
# Forward pass
logits, _ = self.forward(idx_cond)
logits = logits[:, -1, :] / temperature
# Top-k sampling
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
# Sample
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
# Append
input_ids = torch.cat([input_ids, next_token], dim=1)
return input_ids
# Training example
model = GPTModel(vocab_size=50257, d_model=768, num_heads=12, num_layers=12)
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4, betas=(0.9, 0.95))
def train_step(input_ids, targets):
optimizer.zero_grad()
logits, loss = model(input_ids, targets)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
return loss.item()
Vision Transformer for Generation (ViT-VQGAN)
class VisionTransformerGenerator(nn.Module):
def __init__(self, img_size=256, patch_size=16, embed_dim=768, num_heads=12, depth=12):
super(VisionTransformerGenerator, self).__init__()
self.patch_size = patch_size
self.num_patches = (img_size // patch_size) ** 2
# Patch embedding
self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
# Position embedding
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
# Transformer blocks
self.blocks = nn.ModuleList([
TransformerBlock(embed_dim, num_heads, embed_dim * 4)
for _ in range(depth)
])
# Decoder: four stride-2 upsamples take the 16x16 patch grid back to 256x256
self.decoder = nn.Sequential(
nn.ConvTranspose2d(embed_dim, 256, 4, 2, 1),
nn.ReLU(),
nn.ConvTranspose2d(256, 128, 4, 2, 1),
nn.ReLU(),
nn.ConvTranspose2d(128, 64, 4, 2, 1),
nn.ReLU(),
nn.ConvTranspose2d(64, 3, 4, 2, 1),
nn.Tanh()
)
def forward(self, x):
# Patch embedding
x = self.patch_embed(x)
x = x.flatten(2).transpose(1, 2)
# Add position embedding
x = x + self.pos_embed
# Transformer blocks
for block in self.blocks:
x = block(x)
# Reshape for decoder
b, n, c = x.shape
h = w = int(math.sqrt(n))
x = x.transpose(1, 2).reshape(b, c, h, w)
# Decode
x = self.decoder(x)
return x
Diffusion Models
Improved DDPM
class ImprovedDDPM(nn.Module):
def __init__(self, img_channels=3, base_channels=128, timesteps=1000):
super(ImprovedDDPM, self).__init__()
self.timesteps = timesteps
# Variance schedule (cosine)
self.betas = self._cosine_beta_schedule(timesteps)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
# U-Net architecture
self.time_embed = nn.Sequential(
nn.Linear(base_channels, base_channels * 4),
nn.SiLU(),
nn.Linear(base_channels * 4, base_channels * 4)
)
# Encoder
self.down1 = self._make_down_block(img_channels, base_channels)
self.down2 = self._make_down_block(base_channels, base_channels * 2)
self.down3 = self._make_down_block(base_channels * 2, base_channels * 4)
# Bottleneck
self.mid = self._make_res_block(base_channels * 4)
# Decoder
self.up3 = self._make_up_block(base_channels * 4, base_channels * 2)
self.up2 = self._make_up_block(base_channels * 2, base_channels)
self.up1 = self._make_up_block(base_channels, img_channels)
def _cosine_beta_schedule(self, timesteps, s=0.008):
"""Cosine schedule for betas"""
steps = timesteps + 1
x = torch.linspace(0, timesteps, steps)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0.0001, 0.9999)
def _make_down_block(self, in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, padding=1),
nn.GroupNorm(8, out_channels),
nn.SiLU(),
nn.Conv2d(out_channels, out_channels, 3, padding=1),
nn.GroupNorm(8, out_channels),
nn.SiLU(),
nn.MaxPool2d(2)
)
def _make_up_block(self, in_channels, out_channels):
return nn.Sequential(
nn.Upsample(scale_factor=2),
nn.Conv2d(in_channels, out_channels, 3, padding=1),
nn.GroupNorm(8, out_channels),
nn.SiLU(),
nn.Conv2d(out_channels, out_channels, 3, padding=1),
nn.GroupNorm(8, out_channels),
nn.SiLU()
)
def _make_res_block(self, channels):
return nn.Sequential(
nn.Conv2d(channels, channels, 3, padding=1),
nn.GroupNorm(8, channels),
nn.SiLU(),
nn.Conv2d(channels, channels, 3, padding=1),
nn.GroupNorm(8, channels),
nn.SiLU()
)
def forward(self, x, t):
"""Predict noise"""
# Time embedding
t_emb = self._get_timestep_embedding(t, x.device)
t_emb = self.time_embed(t_emb)
# U-Net forward
h1 = self.down1(x)
h2 = self.down2(h1)
h3 = self.down3(h2)
h = self.mid(h3)
h = self.up3(h + h3)
h = self.up2(h + h2)
h = self.up1(h + h1)
return h
def _get_timestep_embedding(self, timesteps, device, dim=128):
"""Sinusoidal positional embedding"""
half_dim = dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = timesteps[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
return emb
@torch.no_grad()
def sample(self, batch_size, img_size, device):
"""DDPM sampling"""
# Start from random noise
img = torch.randn(batch_size, 3, img_size, img_size, device=device)
for t in reversed(range(self.timesteps)):
t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
# Predict noise
predicted_noise = self.forward(img, t_batch)
# Compute x_{t-1}
alpha_t = self.alphas[t]
alpha_cumprod_t = self.alphas_cumprod[t]
beta_t = self.betas[t]
if t > 0:
noise = torch.randn_like(img)
else:
noise = torch.zeros_like(img)
img = (1 / alpha_t.sqrt()) * (img - ((1 - alpha_t) / (1 - alpha_cumprod_t).sqrt()) * predicted_noise)
img = img + beta_t.sqrt() * noise
return img
Latent Diffusion Models (Stable Diffusion)
class LatentDiffusion(nn.Module):
def __init__(self, vae, unet, text_encoder):
super(LatentDiffusion, self).__init__()
self.vae = vae # VAE for encoding/decoding images
self.unet = unet # U-Net for denoising in latent space
self.text_encoder = text_encoder # CLIP text encoder
self.timesteps = 1000
self.betas = torch.linspace(0.0001, 0.02, self.timesteps)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
def forward(self, images, text_embeddings, t):
"""Training forward pass"""
# Encode images to latent space
with torch.no_grad():
latents = self.vae.encode(images)
# Add noise
noise = torch.randn_like(latents)
noisy_latents = self._add_noise(latents, noise, t)
# Predict noise conditioned on text
predicted_noise = self.unet(noisy_latents, t, text_embeddings)
# Loss
loss = F.mse_loss(predicted_noise, noise)
return loss
def _add_noise(self, latents, noise, t):
"""Add noise according to schedule"""
sqrt_alpha_prod = self.alphas_cumprod[t] ** 0.5
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[t]) ** 0.5
return sqrt_alpha_prod * latents + sqrt_one_minus_alpha_prod * noise
@torch.no_grad()
def generate(self, text, batch_size=1, guidance_scale=7.5):
"""Text-to-image generation"""
# Encode text
text_embeddings = self.text_encoder(text)
# Start from random noise in latent space
latents = torch.randn(batch_size, 4, 64, 64)
# Denoising loop
for t in reversed(range(self.timesteps)):
t_batch = torch.full((batch_size,), t)
# Predict noise with and without conditioning (classifier-free guidance)
noise_pred_text = self.unet(latents, t_batch, text_embeddings)
noise_pred_uncond = self.unet(latents, t_batch, None)
# Apply guidance
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# Update latents
latents = self._denoise_step(latents, noise_pred, t)
# Decode latents to images
images = self.vae.decode(latents)
return images
def _denoise_step(self, latents, noise_pred, t):
"""Single denoising step (DDPM posterior mean plus noise)"""
alpha_t = self.alphas[t]
alpha_cumprod_t = self.alphas_cumprod[t]
beta_t = self.betas[t]
# Posterior mean: (1/sqrt(alpha_t)) * (x_t - beta_t/sqrt(1 - alpha_bar_t) * eps)
mean = (latents - (beta_t / (1 - alpha_cumprod_t).sqrt()) * noise_pred) / alpha_t.sqrt()
if t > 0:
return mean + beta_t.sqrt() * torch.randn_like(latents)
return mean
Vector Quantized Models
VQ-VAE
class VectorQuantizer(nn.Module):
def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25):
super(VectorQuantizer, self).__init__()
self.embedding_dim = embedding_dim
self.num_embeddings = num_embeddings
self.commitment_cost = commitment_cost
self.embeddings = nn.Embedding(num_embeddings, embedding_dim)
self.embeddings.weight.data.uniform_(-1/num_embeddings, 1/num_embeddings)
def forward(self, inputs):
# Flatten input
flat_input = inputs.view(-1, self.embedding_dim)
# Calculate distances
distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
+ torch.sum(self.embeddings.weight**2, dim=1)
- 2 * torch.matmul(flat_input, self.embeddings.weight.t()))
# Get closest embeddings
encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
encodings = torch.zeros(encoding_indices.shape[0], self.num_embeddings, device=inputs.device)
encodings.scatter_(1, encoding_indices, 1)
# Quantize
quantized = torch.matmul(encodings, self.embeddings.weight)
quantized = quantized.view_as(inputs)
# Loss
e_latent_loss = F.mse_loss(quantized.detach(), inputs)
q_latent_loss = F.mse_loss(quantized, inputs.detach())
loss = q_latent_loss + self.commitment_cost * e_latent_loss
# Straight-through estimator
quantized = inputs + (quantized - inputs).detach()
return quantized, loss, encoding_indices
class VQVAE(nn.Module):
def __init__(self, num_embeddings=512, embedding_dim=64):
super(VQVAE, self).__init__()
# Encoder
self.encoder = nn.Sequential(
nn.Conv2d(3, 64, 4, 2, 1),
nn.ReLU(),
nn.Conv2d(64, 128, 4, 2, 1),
nn.ReLU(),
nn.Conv2d(128, embedding_dim, 3, 1, 1)
)
# Vector quantizer
self.vq = VectorQuantizer(num_embeddings, embedding_dim)
# Decoder
self.decoder = nn.Sequential(
nn.ConvTranspose2d(embedding_dim, 128, 3, 1, 1),
nn.ReLU(),
nn.ConvTranspose2d(128, 64, 4, 2, 1),
nn.ReLU(),
nn.ConvTranspose2d(64, 3, 4, 2, 1),
nn.Tanh()
)
def forward(self, x):
z = self.encoder(x)
quantized, vq_loss, _ = self.vq(z)
recon = self.decoder(quantized)
recon_loss = F.mse_loss(recon, x)
return recon, recon_loss + vq_loss
NeRF and 3D Generation
Neural Radiance Fields
class NeRF(nn.Module):
def __init__(self, pos_dim=3, dir_dim=3, hidden_dim=256):
super(NeRF, self).__init__()
# Position encoding
self.pos_encoder = self._positional_encoding
self.dir_encoder = self._positional_encoding
# MLP for density and features
self.density_net = nn.Sequential(
nn.Linear(pos_dim * 2 * 10, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim + 1) # density + features
)
# MLP for color
self.color_net = nn.Sequential(
nn.Linear(hidden_dim + dir_dim * 2 * 4, hidden_dim // 2),
nn.ReLU(),
nn.Linear(hidden_dim // 2, 3),
nn.Sigmoid()
)
def _positional_encoding(self, x, L=10):
"""Positional encoding for coordinates"""
encoding = []
for l in range(L):
encoding.append(torch.sin(2**l * math.pi * x))
encoding.append(torch.cos(2**l * math.pi * x))
return torch.cat(encoding, dim=-1)
def forward(self, positions, directions):
# Encode positions and directions
pos_enc = self.pos_encoder(positions, L=10)
dir_enc = self.dir_encoder(directions, L=4) # color_net expects an L=4 encoding
# Get density and features
density_features = self.density_net(pos_enc)
density = F.relu(density_features[:, :1])
features = density_features[:, 1:]
# Get color
color_input = torch.cat([features, dir_enc], dim=-1)
color = self.color_net(color_input)
return density, color
def render_rays(self, ray_origins, ray_directions, near=2.0, far=6.0, n_samples=64):
"""Volume rendering along rays"""
# Sample points along rays
t_vals = torch.linspace(near, far, n_samples, device=ray_origins.device)
points = ray_origins[:, None, :] + ray_directions[:, None, :] * t_vals[None, :, None]
# Flatten for network
points_flat = points.reshape(-1, 3)
dirs_flat = ray_directions[:, None, :].expand_as(points).reshape(-1, 3)
# Get density and color
density, color = self.forward(points_flat, dirs_flat)
# Reshape
density = density.reshape(points.shape[0], n_samples)
color = color.reshape(points.shape[0], n_samples, 3)
# Volume rendering
dists = t_vals[1:] - t_vals[:-1]
dists = torch.cat([dists, torch.tensor([1e10], device=dists.device)])
alpha = 1.0 - torch.exp(-density * dists)
transmittance = torch.cumprod(1.0 - alpha + 1e-10, dim=-1)
transmittance = torch.cat([torch.ones_like(transmittance[:, :1]), transmittance[:, :-1]], dim=-1)
weights = alpha * transmittance
rgb = torch.sum(weights[:, :, None] * color, dim=1)
return rgb
Multimodal Generative Models
CLIP-guided Generation
class CLIPGuidedGenerator:
def __init__(self, generator, clip_model):
self.generator = generator
self.clip_model = clip_model
def generate(self, text_prompt, num_steps=100, lr=0.1):
"""Generate image guided by CLIP text embedding"""
# Encode text
with torch.no_grad():
text_features = self.clip_model.encode_text(text_prompt)
# Start with random latent
latent = torch.randn(1, 512, requires_grad=True)
optimizer = torch.optim.Adam([latent], lr=lr)
for step in range(num_steps):
optimizer.zero_grad()
# Generate image
image = self.generator(latent)
# Encode image with CLIP
image_features = self.clip_model.encode_image(image)
# CLIP loss (maximize similarity)
loss = -torch.cosine_similarity(text_features, image_features).mean()
loss.backward()
optimizer.step()
if step % 10 == 0:
print(f"Step {step}, Loss: {loss.item():.4f}")
# Generate final image
with torch.no_grad():
final_image = self.generator(latent)
return final_image
Practical Tips
- Model Size: Start small, scale up gradually
- Training Stability: Use gradient clipping, EMA
- Quality Metrics: FID, IS, LPIPS for evaluation
- Computational Efficiency: Use mixed precision, model parallelism
- Fine-tuning: Transfer from pre-trained models
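The EMA tip above can be sketched as a small wrapper that keeps a shadow copy of the model's weights. This is a minimal sketch; the class name and decay value are illustrative.

```python
import copy
import torch

class EMA:
    """Exponential moving average of model parameters.

    Keeps a frozen shadow copy; sampling from the shadow model is
    typically smoother and more stable than from the raw weights.
    """
    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = copy.deepcopy(model).eval()
        for p in self.shadow.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        # shadow <- decay * shadow + (1 - decay) * current
        for ema_p, p in zip(self.shadow.parameters(), model.parameters()):
            ema_p.mul_(self.decay).add_(p, alpha=1 - self.decay)
```

Call `ema.update(model)` after every optimizer step and use `ema.shadow` for evaluation and sampling.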
Resources
- Stable Diffusion: https://github.com/CompVis/stable-diffusion
- DALL-E 2 paper: https://arxiv.org/abs/2204.06125
- Imagen paper: https://arxiv.org/abs/2205.11487
- NeRF: https://www.matthewtancik.com/nerf
Transfer Learning
Transfer learning leverages knowledge from pre-trained models to solve new tasks with limited data.
Table of Contents
- Introduction
- Pre-training Strategies
- Fine-tuning Techniques
- Domain Adaptation
- Few-Shot Learning
- Model Distillation
Introduction
Key Concepts:
- Pre-training: Training on large dataset for general features
- Fine-tuning: Adapting pre-trained model to specific task
- Feature Extraction: Using pre-trained model as fixed feature extractor
- Domain Shift: Difference between source and target distributions
When to Use Transfer Learning:
- Limited target data
- Similar source and target tasks
- Computational constraints
- Need for faster convergence
Pre-training Strategies
Self-Supervised Pre-training
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
# Contrastive Learning (SimCLR)
class SimCLR(nn.Module):
def __init__(self, base_encoder, projection_dim=128):
super(SimCLR, self).__init__()
self.encoder = base_encoder
# Remove classification head
self.encoder.fc = nn.Identity()
# Projection head
self.projection = nn.Sequential(
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, projection_dim)
)
def forward(self, x):
h = self.encoder(x)
z = self.projection(h)
return F.normalize(z, dim=1)
def contrastive_loss(z_i, z_j, temperature=0.5):
"""NT-Xent loss for contrastive learning"""
batch_size = z_i.shape[0]
# Concatenate representations
z = torch.cat([z_i, z_j], dim=0)
# Compute similarity matrix
similarity_matrix = torch.matmul(z, z.T) / temperature
# Create labels
labels = torch.arange(batch_size, device=z.device)
labels = torch.cat([labels + batch_size, labels])
# Mask out self-similarity
mask = torch.eye(2 * batch_size, device=z.device).bool()
similarity_matrix = similarity_matrix.masked_fill(mask, -float('inf'))
# Compute loss
loss = F.cross_entropy(similarity_matrix, labels)
return loss
# Training SimCLR
def train_simclr(model, train_loader, num_epochs=100):
optimizer = optim.Adam(model.parameters(), lr=0.001)
for epoch in range(num_epochs):
total_loss = 0
for (x1, x2), _ in train_loader: # x1, x2 are augmented views
optimizer.zero_grad()
# Get representations
z1 = model(x1)
z2 = model(x2)
# Compute loss
loss = contrastive_loss(z1, z2)
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}")
return model
# Create augmented pairs
class ContrastiveTransform:
def __init__(self):
self.transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
transforms.RandomGrayscale(p=0.2),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
def __call__(self, x):
return self.transform(x), self.transform(x)
Masked Language Modeling (BERT-style)
class MaskedLanguageModel(nn.Module):
def __init__(self, vocab_size, d_model=768, num_heads=12, num_layers=12):
super(MaskedLanguageModel, self).__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.position_embedding = nn.Embedding(512, d_model)
encoder_layer = nn.TransformerEncoderLayer(d_model, num_heads, dim_feedforward=3072)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
self.lm_head = nn.Linear(d_model, vocab_size)
def forward(self, input_ids, attention_mask=None):
# Embeddings
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)
embeddings = self.embedding(input_ids) + self.position_embedding(position_ids)
# Transformer
hidden_states = self.transformer(embeddings.transpose(0, 1)).transpose(0, 1)
# Prediction
logits = self.lm_head(hidden_states)
return logits
def mask_tokens(inputs, tokenizer, mlm_probability=0.15):
"""Prepare masked tokens for MLM"""
labels = inputs.clone()
# Build masking probabilities, excluding special tokens before sampling
probability_matrix = torch.full(labels.shape, mlm_probability)
special_tokens_mask = tokenizer.get_special_tokens_mask(
labels.tolist(), already_has_special_tokens=True
)
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
# Sample the masked positions once, after special tokens are zeroed out
masked_indices = torch.bernoulli(probability_matrix).bool()
# Set labels for non-masked tokens to -100
labels[~masked_indices] = -100
# Replace masked tokens
# 80% [MASK], 10% random, 10% original
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
inputs[indices_replaced] = tokenizer.mask_token_id
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
inputs[indices_random] = random_words[indices_random]
return inputs, labels
# Training loop
def train_mlm(model, train_loader, tokenizer, num_epochs=10):
optimizer = optim.Adam(model.parameters(), lr=5e-5)
criterion = nn.CrossEntropyLoss(ignore_index=-100)
for epoch in range(num_epochs):
total_loss = 0
for batch in train_loader:
input_ids = batch['input_ids']
# Mask tokens
masked_inputs, labels = mask_tokens(input_ids, tokenizer)
# Forward pass
logits = model(masked_inputs)
# Compute loss
loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}")
Fine-tuning Techniques
Standard Fine-tuning
def fine_tune_model(pretrained_model, train_loader, val_loader, num_classes, num_epochs=10):
"""Fine-tune a pre-trained model on a new task"""
# Use the pre-trained model passed in (e.g. models.resnet50(pretrained=True))
model = pretrained_model
# Replace classification head
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, num_classes)
# Optimizer with different learning rates
params = [
{'params': model.fc.parameters(), 'lr': 1e-3}, # New layer: higher LR
{'params': [p for n, p in model.named_parameters() if 'fc' not in n],
'lr': 1e-4} # Pre-trained layers: lower LR
]
optimizer = optim.Adam(params)
criterion = nn.CrossEntropyLoss()
# Training loop
best_val_acc = 0
for epoch in range(num_epochs):
# Training
model.train()
train_loss = 0
for images, labels in train_loader:
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
train_loss += loss.item()
# Validation
model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in val_loader:
outputs = model(images)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
val_acc = 100 * correct / total
print(f"Epoch {epoch+1}, Train Loss: {train_loss/len(train_loader):.4f}, "
f"Val Acc: {val_acc:.2f}%")
# Save best model
if val_acc > best_val_acc:
best_val_acc = val_acc
torch.save(model.state_dict(), 'best_model.pth')
return model
Progressive Unfreezing
def progressive_unfreezing(model, train_loader, num_epochs=20, unfreeze_every=5):
"""Gradually unfreeze layers during fine-tuning"""
# Initially freeze all layers
for param in model.parameters():
param.requires_grad = False
# Only train classification head
for param in model.fc.parameters():
param.requires_grad = True
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
criterion = nn.CrossEntropyLoss()
# Get layer groups (from top to bottom)
layer_groups = [
model.fc,
model.layer4,
model.layer3,
model.layer2,
model.layer1
]
for epoch in range(num_epochs):
# Unfreeze next layer group
if epoch % unfreeze_every == 0 and epoch > 0:
group_idx = min(epoch // unfreeze_every, len(layer_groups) - 1)
print(f"Unfreezing layer group {group_idx}")
for param in layer_groups[group_idx].parameters():
param.requires_grad = True
# Update optimizer
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
lr=1e-3 / (2 ** group_idx))
# Training
model.train()
for images, labels in train_loader:
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
Discriminative Learning Rates
def get_discriminative_lr_params(model, base_lr=1e-3, lr_mult=2.6):
"""Different learning rates for different layers"""
params = []
# Get all layer names
layer_names = [name for name, _ in model.named_parameters()]
# Group layers
num_layers = len(layer_names)
for idx, (name, param) in enumerate(model.named_parameters()):
# named_parameters() runs from bottom (early layers) to top (head):
# divide so early layers get the smallest LR and the head gets base_lr
layer_lr = base_lr / (lr_mult ** (num_layers - idx - 1))
params.append({'params': param, 'lr': layer_lr})
return params
# Usage
model = models.resnet50(pretrained=True)
params = get_discriminative_lr_params(model)
optimizer = optim.Adam(params)
Adapter Layers
class AdapterLayer(nn.Module):
"""Lightweight adapter for efficient fine-tuning"""
def __init__(self, input_dim, bottleneck_dim=64):
super(AdapterLayer, self).__init__()
self.down_project = nn.Linear(input_dim, bottleneck_dim)
self.up_project = nn.Linear(bottleneck_dim, input_dim)
self.activation = nn.ReLU()
def forward(self, x):
residual = x
x = self.down_project(x)
x = self.activation(x)
x = self.up_project(x)
return x + residual
class ModelWithAdapters(nn.Module):
"""Add adapters to pre-trained model"""
def __init__(self, base_model, adapter_dim=64):
super(ModelWithAdapters, self).__init__()
self.base_model = base_model
# Freeze base model
for param in base_model.parameters():
param.requires_grad = False
# Add adapters after each transformer block
self.adapters = nn.ModuleList([
AdapterLayer(768, adapter_dim) # Assuming 768 hidden dim
for _ in range(12) # For each layer
])
def forward(self, x):
# Forward through base model with adapters
for i, (layer, adapter) in enumerate(zip(self.base_model.layers, self.adapters)):
x = layer(x)
x = adapter(x)
return x
LoRA (Low-Rank Adaptation)
class LoRALayer(nn.Module):
"""Low-Rank Adaptation layer"""
def __init__(self, input_dim, output_dim, rank=4, alpha=1):
super(LoRALayer, self).__init__()
self.rank = rank
self.alpha = alpha
# Low-rank matrices
self.lora_A = nn.Parameter(torch.randn(input_dim, rank) * 0.01)
self.lora_B = nn.Parameter(torch.zeros(rank, output_dim))
self.scaling = alpha / rank
def forward(self, x):
# Low-rank update: x @ A @ B
return (x @ self.lora_A @ self.lora_B) * self.scaling
class LinearWithLoRA(nn.Module):
"""Linear layer with LoRA adaptation"""
def __init__(self, linear_layer, rank=4):
super(LinearWithLoRA, self).__init__()
self.linear = linear_layer
# Freeze original weights
self.linear.weight.requires_grad = False
if self.linear.bias is not None:
self.linear.bias.requires_grad = False
# Add LoRA
self.lora = LoRALayer(
self.linear.in_features,
self.linear.out_features,
rank=rank
)
def forward(self, x):
return self.linear(x) + self.lora(x)
def add_lora_to_model(model, rank=4):
"""Add LoRA to all linear layers"""
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
parent_name = '.'.join(name.split('.')[:-1])
child_name = name.split('.')[-1]
parent = dict(model.named_modules())[parent_name] if parent_name else model
setattr(parent, child_name, LinearWithLoRA(module, rank))
return model
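Two properties make LoRA work: because `lora_B` is initialized to zero, the adapted layer starts out exactly equal to the frozen layer, and the number of trainable parameters is far smaller than the full weight matrix. A self-contained check (re-declaring a minimal `LoRALayer` as above so it runs on its own):

```python
import torch
import torch.nn as nn

class LoRALayer(nn.Module):
    def __init__(self, input_dim, output_dim, rank=4, alpha=1):
        super().__init__()
        self.lora_A = nn.Parameter(torch.randn(input_dim, rank) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(rank, output_dim))  # zero init
        self.scaling = alpha / rank
    def forward(self, x):
        return (x @ self.lora_A @ self.lora_B) * self.scaling

base = nn.Linear(16, 16)
lora = LoRALayer(16, 16, rank=4)

x = torch.randn(2, 16)
# Because lora_B starts at zero, the LoRA update is initially a no-op:
out = base(x) + lora(x)

# Trainable parameters: A (16*4) + B (4*16) = 128 vs 16*16 + 16 = 272 for the full layer
lora_params = sum(p.numel() for p in lora.parameters())
full_params = sum(p.numel() for p in base.parameters())
```

The gap grows with layer size: for a 4096-dimensional transformer layer, rank-4 LoRA trains roughly 0.2% of the parameters of the full matrix.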
Domain Adaptation
Domain Adversarial Neural Network (DANN)
class GradientReversalLayer(torch.autograd.Function):
"""Gradient reversal layer for domain adaptation"""
@staticmethod
def forward(ctx, x, alpha):
ctx.alpha = alpha
return x.view_as(x)
@staticmethod
def backward(ctx, grad_output):
return -ctx.alpha * grad_output, None
class DomainAdversarialNetwork(nn.Module):
def __init__(self, feature_extractor, num_classes, num_domains=2):
super(DomainAdversarialNetwork, self).__init__()
self.feature_extractor = feature_extractor
# Label classifier
self.label_classifier = nn.Sequential(
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(256, num_classes)
)
# Domain classifier
self.domain_classifier = nn.Sequential(
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(256, num_domains)
)
def forward(self, x, alpha=1.0):
# Extract features
features = self.feature_extractor(x)
# Label prediction
label_pred = self.label_classifier(features)
# Domain prediction with gradient reversal
reversed_features = GradientReversalLayer.apply(features, alpha)
domain_pred = self.domain_classifier(reversed_features)
return label_pred, domain_pred
def train_dann(model, source_loader, target_loader, num_epochs=50):
"""Train DANN for domain adaptation"""
optimizer = optim.Adam(model.parameters(), lr=0.001)
label_criterion = nn.CrossEntropyLoss()
domain_criterion = nn.CrossEntropyLoss()
for epoch in range(num_epochs):
model.train()
for (source_data, source_labels), (target_data, _) in zip(source_loader, target_loader):
batch_size = source_data.size(0)
# Compute alpha for gradient reversal
p = float(epoch) / num_epochs
alpha = 2. / (1. + math.exp(-10 * p)) - 1  # gradient-reversal schedule from the DANN paper (requires: import math)
# Source domain
source_label_pred, source_domain_pred = model(source_data, alpha)
source_label_loss = label_criterion(source_label_pred, source_labels)
source_domain_loss = domain_criterion(
source_domain_pred,
torch.zeros(batch_size, dtype=torch.long)
)
# Target domain
_, target_domain_pred = model(target_data, alpha)
target_domain_loss = domain_criterion(
target_domain_pred,
torch.ones(target_data.size(0), dtype=torch.long)
)
# Total loss
loss = source_label_loss + source_domain_loss + target_domain_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
Maximum Mean Discrepancy (MMD)
def mmd_loss(source_features, target_features, kernel='rbf', gamma=1.0):
"""Compute MMD between source and target distributions"""
def gaussian_kernel(x, y, gamma):
"""RBF kernel"""
x_size = x.size(0)
y_size = y.size(0)
dim = x.size(1)
x = x.unsqueeze(1) # (x_size, 1, dim)
y = y.unsqueeze(0) # (1, y_size, dim)
tiled_x = x.expand(x_size, y_size, dim)
tiled_y = y.expand(x_size, y_size, dim)
kernel_input = (tiled_x - tiled_y).pow(2).sum(2)
return torch.exp(-gamma * kernel_input)
# Compute kernels
xx = gaussian_kernel(source_features, source_features, gamma).mean()
yy = gaussian_kernel(target_features, target_features, gamma).mean()
xy = gaussian_kernel(source_features, target_features, gamma).mean()
# MMD
return xx + yy - 2 * xy
class MMDDomainAdaptation(nn.Module):
def __init__(self, feature_extractor, num_classes):
super(MMDDomainAdaptation, self).__init__()
self.feature_extractor = feature_extractor
self.classifier = nn.Linear(512, num_classes)
def forward(self, x):
features = self.feature_extractor(x)
output = self.classifier(features)
return output, features
def train_mmd(model, source_loader, target_loader, num_epochs=50, lambda_mmd=0.1):
"""Train with MMD for domain adaptation"""
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
for epoch in range(num_epochs):
for (source_data, source_labels), (target_data, _) in zip(source_loader, target_loader):
optimizer.zero_grad()
# Forward pass
source_pred, source_features = model(source_data)
_, target_features = model(target_data)
# Classification loss
class_loss = criterion(source_pred, source_labels)
# MMD loss
mmd = mmd_loss(source_features, target_features)
# Total loss
loss = class_loss + lambda_mmd * mmd
loss.backward()
optimizer.step()
Few-Shot Learning
Prototypical Networks
class PrototypicalNetwork(nn.Module):
def __init__(self, encoder):
super(PrototypicalNetwork, self).__init__()
self.encoder = encoder
def forward(self, support_set, support_labels, query_set, n_way, k_shot):
"""
support_set: (n_way * k_shot, C, H, W)
query_set: (n_query, C, H, W)
"""
# Encode support and query sets
support_embeddings = self.encoder(support_set)
query_embeddings = self.encoder(query_set)
# Compute prototypes (class centroids)
prototypes = []
for c in range(n_way):
class_embeddings = support_embeddings[c * k_shot:(c + 1) * k_shot]
prototype = class_embeddings.mean(dim=0)
prototypes.append(prototype)
prototypes = torch.stack(prototypes)
# Compute distances
distances = torch.cdist(query_embeddings, prototypes)
# Convert to probabilities
log_p_y = F.log_softmax(-distances, dim=1)
return log_p_y
def train_prototypical(model, train_loader, num_episodes=1000, n_way=5, k_shot=5):
"""Train prototypical network"""
optimizer = optim.Adam(model.parameters(), lr=0.001)
for episode in range(num_episodes):
# Sample episode
support_set, support_labels, query_set, query_labels = train_loader.sample_episode(n_way, k_shot)
# Forward pass
log_p_y = model(support_set, support_labels, query_set, n_way, k_shot)
# Loss
loss = F.nll_loss(log_p_y, query_labels)
# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
if episode % 100 == 0:
print(f"Episode {episode}, Loss: {loss.item():.4f}")
MAML (Model-Agnostic Meta-Learning)
class MAML:
def __init__(self, model, inner_lr=0.01, outer_lr=0.001):
self.model = model
self.inner_lr = inner_lr
self.outer_optimizer = optim.Adam(model.parameters(), lr=outer_lr)
def inner_update(self, support_x, support_y, num_steps=5):
"""Adapt model on support set"""
# Clone model for inner loop
adapted_params = {name: param.clone() for name, param in self.model.named_parameters()}
for step in range(num_steps):
# Forward pass using the adapted parameters (functional call),
# not the model's own registered parameters
logits = self.model.forward_with_params(support_x, adapted_params)
loss = F.cross_entropy(logits, support_y)
# Compute gradients w.r.t. the adapted parameters
grads = torch.autograd.grad(loss, list(adapted_params.values()), create_graph=True)
# Update adapted parameters
adapted_params = {
name: param - self.inner_lr * grad
for (name, param), grad in zip(adapted_params.items(), grads)
}
return adapted_params
def meta_update(self, tasks):
"""Meta-update on batch of tasks"""
self.outer_optimizer.zero_grad()
meta_loss = 0
for support_x, support_y, query_x, query_y in tasks:
# Inner loop: adapt to task
adapted_params = self.inner_update(support_x, support_y)
# Outer loop: evaluate on query set
with torch.set_grad_enabled(True):
query_logits = self.model.forward_with_params(query_x, adapted_params)
task_loss = F.cross_entropy(query_logits, query_y)
meta_loss += task_loss
# Meta-gradient step
meta_loss /= len(tasks)
meta_loss.backward()
self.outer_optimizer.step()
return meta_loss.item()
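The `forward_with_params` method used above is not part of `nn.Module`; it has to be supplied by the model. One way to implement it (a sketch, assuming PyTorch ≥ 2.0) is `torch.func.functional_call`, which runs a module's forward pass with an externally supplied parameter dict:

```python
import torch
import torch.nn as nn
from torch.func import functional_call

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 3)

    def forward(self, x):
        return self.fc(x)

    def forward_with_params(self, x, params):
        # Run this module's forward pass with `params` (a dict mapping
        # parameter names to tensors) instead of the registered parameters.
        return functional_call(self, params, (x,))

net = TinyNet()
x = torch.randn(5, 4)
# With the model's own parameters, the two calls agree:
params = dict(net.named_parameters())
out_a = net(x)
out_b = net.forward_with_params(x, params)
```

In MAML's inner loop, `params` would be the `adapted_params` dict, so gradients can flow through the adaptation steps back to the original weights.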
Model Distillation
Knowledge distillation transfers knowledge from large teacher to small student.
class DistillationTrainer:
def __init__(self, teacher_model, student_model, temperature=3.0, alpha=0.5):
self.teacher = teacher_model
self.student = student_model
self.temperature = temperature
self.alpha = alpha
# Freeze teacher
for param in self.teacher.parameters():
param.requires_grad = False
self.teacher.eval()
def distillation_loss(self, student_logits, teacher_logits, labels):
"""Compute distillation loss"""
# Hard loss (student vs true labels)
hard_loss = F.cross_entropy(student_logits, labels)
# Soft loss (student vs teacher)
soft_student = F.log_softmax(student_logits / self.temperature, dim=1)
soft_teacher = F.softmax(teacher_logits / self.temperature, dim=1)
soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean')
soft_loss *= self.temperature ** 2
# Combined loss
loss = self.alpha * hard_loss + (1 - self.alpha) * soft_loss
return loss
def train(self, train_loader, num_epochs=10):
"""Train student with distillation"""
optimizer = optim.Adam(self.student.parameters(), lr=0.001)
for epoch in range(num_epochs):
self.student.train()
total_loss = 0
for images, labels in train_loader:
# Teacher predictions
with torch.no_grad():
teacher_logits = self.teacher(images)
# Student predictions
student_logits = self.student(images)
# Distillation loss
loss = self.distillation_loss(student_logits, teacher_logits, labels)
# Backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}")
# Example usage
teacher = models.resnet50(pretrained=True)
student = models.resnet18(pretrained=False)
trainer = DistillationTrainer(teacher, student, temperature=3.0, alpha=0.5)
# trainer.train(train_loader, num_epochs=10)
Practical Tips
- Start with Pre-trained Models: Use ImageNet, BERT, GPT weights
- Learning Rate: Use smaller LR for pre-trained layers
- Gradual Unfreezing: Unfreeze layers progressively
- Data Augmentation: Critical when fine-tuning with limited data
- Early Stopping: Monitor validation to prevent overfitting
- Adapter Methods: More efficient than full fine-tuning
Resources
- Hugging Face Transformers: https://huggingface.co/transformers/
- timm (PyTorch Image Models): https://github.com/rwightman/pytorch-image-models
- "Transfer Learning" book by Tan et al.
- Papers with Code Transfer Learning: https://paperswithcode.com/task/transfer-learning
Transformers
Transformers are a type of deep learning model introduced in the paper "Attention is All You Need" by Vaswani et al. in 2017. They have revolutionized the field of natural language processing (NLP) and have been widely adopted in various applications, including machine translation, text summarization, and sentiment analysis.
Key Concepts
- Attention Mechanism: The core innovation of transformers is the self-attention mechanism, which allows the model to weigh the importance of different words in a sentence when making predictions. This enables the model to capture long-range dependencies and relationships between words more effectively than previous architectures like RNNs and LSTMs.
- Encoder-Decoder Architecture: The transformer model consists of two main components: the encoder and the decoder. The encoder processes the input data and generates a set of attention-based representations, while the decoder uses these representations to produce the output sequence.
- Positional Encoding: Since transformers do not have a built-in notion of sequence order (unlike RNNs), they use positional encodings to inject information about the position of each word in the input sequence. This allows the model to understand the order of words.
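The sinusoidal positional encodings from the original paper can be computed directly; a short sketch (the function name is an assumption):

```python
import math
import torch

def sinusoidal_positional_encoding(seq_len, d_model):
    """Sinusoidal positional encodings from "Attention is All You Need".

    PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
    """
    position = torch.arange(seq_len).unsqueeze(1).float()            # (seq_len, 1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float()
                         * (-math.log(10000.0) / d_model))           # (d_model/2,)
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
    return pe

pe = sinusoidal_positional_encoding(seq_len=50, d_model=16)
```

These encodings are added to the token embeddings before the first encoder layer; because each dimension is a sinusoid of a different wavelength, relative positions correspond to fixed linear transformations of the encoding.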
Attention Mechanism: Deep Dive
The attention mechanism is the heart of the transformer architecture. It allows the model to focus on different parts of the input sequence when processing each element. Let's explore this in detail with mathematical formulations and PyTorch implementations.
Scaled Dot-Product Attention
The fundamental building block of transformer attention is the Scaled Dot-Product Attention. Given three matrices:
- $Q$ (Query): What we're looking for
- $K$ (Key): What we're matching against
- $V$ (Value): The actual information we want to retrieve
The attention mechanism computes:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$
Where:
- $d_k$ is the dimension of the key vectors
- The division by $\sqrt{d_k}$ is scaling to prevent the dot products from growing too large
- softmax normalizes the scores to create a probability distribution
PyTorch Implementation: Scaled Dot-Product Attention
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def scaled_dot_product_attention(query, key, value, mask=None):
"""
Compute scaled dot-product attention.
Args:
query: Query tensor of shape (batch_size, num_heads, seq_len_q, d_k)
key: Key tensor of shape (batch_size, num_heads, seq_len_k, d_k)
value: Value tensor of shape (batch_size, num_heads, seq_len_v, d_v)
mask: Optional mask tensor
Returns:
output: Attention output (batch_size, num_heads, seq_len_q, d_v)
attention_weights: Attention weights (batch_size, num_heads, seq_len_q, seq_len_k)
"""
# Get the dimension of keys
d_k = query.size(-1)
# Step 1: Compute Q @ K^T
# query: (batch, heads, seq_len_q, d_k)
# key.transpose(-2, -1): (batch, heads, d_k, seq_len_k)
# scores: (batch, heads, seq_len_q, seq_len_k)
scores = torch.matmul(query, key.transpose(-2, -1))
print(f"After Q @ K^T - Shape: {scores.shape}")
print(f"Sample scores (first 3x3):\n{scores[0, 0, :3, :3]}\n")
# Step 2: Scale by sqrt(d_k)
scores = scores / math.sqrt(d_k)
print(f"After scaling by √{d_k} = {math.sqrt(d_k):.2f}")
print(f"Scaled scores (first 3x3):\n{scores[0, 0, :3, :3]}\n")
# Step 3: Apply mask (if provided)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
print(f"After masking - Shape: {scores.shape}")
# Step 4: Apply softmax to get attention weights
# Softmax is applied over the last dimension (seq_len_k)
attention_weights = F.softmax(scores, dim=-1)
print(f"Attention weights - Shape: {attention_weights.shape}")
print(f"Sample attention weights (first 3x3):\n{attention_weights[0, 0, :3, :3]}")
print(f"Sum of first row (should be 1.0): {attention_weights[0, 0, 0].sum()}\n")
# Step 5: Multiply by values
# attention_weights: (batch, heads, seq_len_q, seq_len_k)
# value: (batch, heads, seq_len_v, d_v) [seq_len_v == seq_len_k]
# output: (batch, heads, seq_len_q, d_v)
output = torch.matmul(attention_weights, value)
print(f"Final output - Shape: {output.shape}\n")
return output, attention_weights
# Example: Let's trace through a simple example
batch_size = 2
num_heads = 1
seq_len = 4
d_k = 8
d_v = 8
# Create sample tensors
torch.manual_seed(42)
query = torch.randn(batch_size, num_heads, seq_len, d_k)
key = torch.randn(batch_size, num_heads, seq_len, d_k)
value = torch.randn(batch_size, num_heads, seq_len, d_v)
print("="*60)
print("SCALED DOT-PRODUCT ATTENTION - STEP BY STEP")
print("="*60)
print(f"Query shape: {query.shape}")
print(f"Key shape: {key.shape}")
print(f"Value shape: {value.shape}\n")
output, attn_weights = scaled_dot_product_attention(query, key, value)
print(f"Final output shape: {output.shape}")
print(f"Final attention weights shape: {attn_weights.shape}")
Output explanation:
============================================================
SCALED DOT-PRODUCT ATTENTION - STEP BY STEP
============================================================
Query shape: torch.Size([2, 1, 4, 8])
Key shape: torch.Size([2, 1, 4, 8])
Value shape: torch.Size([2, 1, 4, 8])
After Q @ K^T - Shape: torch.Size([2, 1, 4, 4])
Sample scores (first 3x3):
tensor([[ 0.6240, -1.2613, 1.4199],
[-1.8847, 4.0367, -0.5234],
[ 2.1563, -2.5678, 0.8234]])
After scaling by √8 = 2.83
Scaled scores (first 3x3):
tensor([[ 0.2207, -0.4460, 0.5021],
[-0.6661, 1.4271, -0.1850],
[ 0.7624, -0.9078, 0.2911]])
Attention weights - Shape: torch.Size([2, 1, 4, 4])
Sample attention weights (first 3x3):
tensor([[0.2789, 0.1425, 0.3672],
[0.1056, 0.8236, 0.1680],
[0.3924, 0.0731, 0.2458]])
Sum of first row (should be 1.0): 1.0
Final output - Shape: torch.Size([2, 1, 4, 8])
Understanding the Matrix Operations
Let's break down what's happening at each step:
- Query-Key Dot Product ($QK^T$):
  - Each query vector (row in $Q$) is compared against all key vectors (rows in $K$)
  - The dot product measures similarity: higher values = more similar
  - Shape: (batch, heads, seq_len_q, d_k) @ (batch, heads, d_k, seq_len_k) → (batch, heads, seq_len_q, seq_len_k)
- Scaling:
  - Dividing by $\sqrt{d_k}$ prevents the dot products from becoming too large
  - Large dot products → very small gradients after softmax → slow learning
  - This is crucial for stable training
- Softmax:
  - Converts raw scores into a probability distribution
  - Each row sums to 1.0
  - Higher scores get higher probabilities (attention weights)
- Weighted Sum (Attention @ Value):
  - Uses attention weights to create a weighted combination of value vectors
  - Each output position is a mixture of all value vectors
  - The weights determine how much each value contributes
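The scaling step can be checked numerically: for independent unit-variance query and key vectors, the dot product has variance roughly $d_k$, so its standard deviation grows like $\sqrt{d_k}$ and the softmax would saturate. Dividing by $\sqrt{d_k}$ brings it back to roughly unit scale. A small demo (the variable names here are local to this sketch):

```python
import math
import torch

torch.manual_seed(0)
d_k = 512
n = 10000

# n independent query/key pairs with unit-variance components
q = torch.randn(n, d_k)
k = torch.randn(n, d_k)

# Raw dot products have std ≈ sqrt(d_k) ≈ 22.6 for d_k = 512
raw = (q * k).sum(dim=1)
# After scaling, the std is ≈ 1, keeping softmax inputs in a sane range
scaled = raw / math.sqrt(d_k)

raw_std = raw.std().item()
scaled_std = scaled.std().item()
```

Without the scaling, a softmax over scores with std ≈ 22 would put nearly all of its mass on the single largest score, producing vanishing gradients for every other position.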
Multi-Head Attention
Multi-head attention runs multiple attention operations in parallel, allowing the model to attend to different aspects of the input simultaneously.
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
"""
Multi-Head Attention module.
Args:
d_model: Total dimension of the model (e.g., 512)
num_heads: Number of attention heads (e.g., 8)
"""
super(MultiHeadAttention, self).__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # Dimension per head
# Linear projections for Q, K, V
self.W_q = nn.Linear(d_model, d_model) # Query projection
self.W_k = nn.Linear(d_model, d_model) # Key projection
self.W_v = nn.Linear(d_model, d_model) # Value projection
# Output projection
self.W_o = nn.Linear(d_model, d_model)
def split_heads(self, x):
"""
Split the last dimension into (num_heads, d_k).
Transpose to get shape: (batch_size, num_heads, seq_len, d_k)
"""
batch_size, seq_len, d_model = x.size()
# Reshape to (batch_size, seq_len, num_heads, d_k)
x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
# Transpose to (batch_size, num_heads, seq_len, d_k)
return x.transpose(1, 2)
def combine_heads(self, x):
"""
Inverse of split_heads.
Input: (batch_size, num_heads, seq_len, d_k)
Output: (batch_size, seq_len, d_model)
"""
batch_size, num_heads, seq_len, d_k = x.size()
# Transpose to (batch_size, seq_len, num_heads, d_k)
x = x.transpose(1, 2).contiguous()
# Reshape to (batch_size, seq_len, d_model)
return x.view(batch_size, seq_len, self.d_model)
def forward(self, query, key, value, mask=None):
"""
Forward pass of multi-head attention.
Args:
query: (batch_size, seq_len_q, d_model)
key: (batch_size, seq_len_k, d_model)
value: (batch_size, seq_len_v, d_model)
mask: Optional mask
Returns:
output: (batch_size, seq_len_q, d_model)
attention_weights: (batch_size, num_heads, seq_len_q, seq_len_k)
"""
batch_size = query.size(0)
# Step 1: Linear projections
# Each of these operations: (batch, seq_len, d_model) → (batch, seq_len, d_model)
Q = self.W_q(query)
K = self.W_k(key)
V = self.W_v(value)
print(f"\n{'='*60}")
print("MULTI-HEAD ATTENTION - DETAILED STEPS")
print(f"{'='*60}")
print(f"Input shapes - Q: {query.shape}, K: {key.shape}, V: {value.shape}")
print(f"\nAfter linear projections:")
print(f"Q: {Q.shape}, K: {K.shape}, V: {V.shape}")
# Step 2: Split into multiple heads
# (batch, seq_len, d_model) → (batch, num_heads, seq_len, d_k)
Q = self.split_heads(Q)
K = self.split_heads(K)
V = self.split_heads(V)
print(f"\nAfter splitting into {self.num_heads} heads:")
print(f"Q: {Q.shape}, K: {K.shape}, V: {V.shape}")
print(f"Each head has dimension: {self.d_k}")
# Step 3: Scaled dot-product attention
# For each head: (batch, 1, seq_len_q, d_k) with (batch, 1, seq_len_k, d_k)
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attention_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attention_weights, V)
print(f"\nAfter attention computation:")
print(f"Attention scores: {scores.shape}")
print(f"Attention weights: {attention_weights.shape}")
print(f"Attention output: {output.shape}")
# Step 4: Concatenate heads
# (batch, num_heads, seq_len, d_k) → (batch, seq_len, d_model)
output = self.combine_heads(output)
print(f"\nAfter combining heads: {output.shape}")
# Step 5: Final linear projection
# (batch, seq_len, d_model) → (batch, seq_len, d_model)
output = self.W_o(output)
print(f"After final projection: {output.shape}")
print(f"{'='*60}\n")
return output, attention_weights
# Example usage with detailed tracking
d_model = 512
num_heads = 8
batch_size = 2
seq_len = 10
# Create sample input
x = torch.randn(batch_size, seq_len, d_model)
# Initialize multi-head attention
mha = MultiHeadAttention(d_model, num_heads)
# Forward pass (using x for query, key, and value - this is self-attention)
output, attn_weights = mha(x, x, x)
print(f"\nFinal Results:")
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
print(f"\nAttention weights for first head, first query position:")
print(attn_weights[0, 0, 0, :]) # Should sum to 1.0
print(f"Sum: {attn_weights[0, 0, 0, :].sum()}")
Visualizing Attention: A Concrete Example
Let's see how attention works on actual text:
import torch
import torch.nn as nn
import torch.nn.functional as F
# Simple example: "The cat sat on the mat"
sentence = ["The", "cat", "sat", "on", "the", "mat"]
seq_len = len(sentence)
d_model = 4 # Small for visualization
# Create simple embeddings (normally these would be learned)
# Each word gets a random vector
torch.manual_seed(42)
embeddings = torch.randn(1, seq_len, d_model)
# Simple attention (1 head for clarity)
class SimpleAttention(nn.Module):
def __init__(self, d_model):
super().__init__()
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
def forward(self, x):
Q = self.W_q(x)
K = self.W_k(x)
V = self.W_v(x)
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
attn_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, V)
return output, attn_weights
# Create and run the attention
attn = SimpleAttention(d_model)
output, weights = attn(embeddings)
# Visualize attention weights
print("Attention Weight Matrix:")
print("(Each row shows which words that word attends to)\n")
print(" ", " ".join(f"{w:>5}" for w in sentence))
print("-" * 50)
for i, word in enumerate(sentence):
print(f"{word:>7} |", " ".join(f"{weights[0, i, j].item():5.3f}" for j in range(seq_len)))
print("\nInterpretation:")
print("- Each row represents a query word")
print("- Each column represents a key word")
print("- Values show how much the query word 'attends to' each key word")
print("- Higher values = stronger attention")
print("- Each row sums to 1.0")
Masked Attention (for Decoder)
In the decoder, we use masked attention to prevent positions from attending to future positions:
def create_causal_mask(seq_len):
"""
Create a causal mask for decoder self-attention.
Prevents attending to future positions.
Returns a lower triangular matrix:
[[1, 0, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0],
[1, 1, 1, 1]]
"""
mask = torch.tril(torch.ones(seq_len, seq_len))
return mask.unsqueeze(0).unsqueeze(0) # Add batch and head dimensions
# Example with masking
seq_len = 4
mask = create_causal_mask(seq_len)
print("Causal Mask:")
print(mask[0, 0])
print("\nThis ensures that:")
print("- Position 0 can only see position 0")
print("- Position 1 can see positions 0-1")
print("- Position 2 can see positions 0-2")
print("- Position 3 can see all positions 0-3")
# Apply masked attention
query = torch.randn(1, 1, seq_len, 8)
key = torch.randn(1, 1, seq_len, 8)
value = torch.randn(1, 1, seq_len, 8)
output, attn_weights = scaled_dot_product_attention(query, key, value, mask)
print("\nAttention weights with masking:")
print(attn_weights[0, 0])
print("\nNotice how future positions (upper triangle) have ~0 attention weight")
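The reason the large negative fill works: after softmax, those positions receive essentially zero probability mass. A tiny self-contained check (the score values here are arbitrary):

```python
import torch
import torch.nn.functional as F

# Two real scores followed by two masked-out positions filled with -1e9
scores = torch.tensor([2.0, 1.0, -1e9, -1e9])
weights = F.softmax(scores, dim=-1)
print(weights)  # last two entries are ~0; first two sum to ~1
```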
Complete Self-Attention Layer with PyTorch
Here's a complete implementation you can use in practice:
import torch
import torch.nn as nn
import math
class SelfAttention(nn.Module):
"""
Complete self-attention layer with all components.
"""
def __init__(self, d_model, num_heads, dropout=0.1):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# Combined QKV projection (more efficient)
self.qkv_proj = nn.Linear(d_model, 3 * d_model)
self.out_proj = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
batch_size, seq_len, d_model = x.shape
# Project to Q, K, V all at once
qkv = self.qkv_proj(x) # (batch, seq_len, 3 * d_model)
# Split into Q, K, V and reshape for multi-head
qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.d_k)
qkv = qkv.permute(2, 0, 3, 1, 4) # (3, batch, heads, seq_len, d_k)
q, k, v = qkv[0], qkv[1], qkv[2]
# Scaled dot-product attention
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = torch.softmax(scores, dim=-1)
attn = self.dropout(attn)
# Combine heads
out = torch.matmul(attn, v) # (batch, heads, seq_len, d_k)
out = out.transpose(1, 2).contiguous() # (batch, seq_len, heads, d_k)
out = out.reshape(batch_size, seq_len, d_model) # (batch, seq_len, d_model)
# Final projection
out = self.out_proj(out)
return out, attn
# Test the complete implementation
model = SelfAttention(d_model=512, num_heads=8, dropout=0.1)
x = torch.randn(2, 10, 512) # (batch=2, seq_len=10, d_model=512)
output, attention = model(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention shape: {attention.shape}")
Architecture
- Encoder: The encoder is composed of multiple identical layers, each containing two main sub-layers:
- Multi-Head Self-Attention: This mechanism allows the model to focus on different parts of the input sequence simultaneously, capturing various relationships between words.
- Feed-Forward Neural Network: After the attention mechanism, the output is passed through a feed-forward neural network, which applies a non-linear transformation.
- Decoder: The decoder also consists of multiple identical layers, with an additional sub-layer for attending to the encoder's output:
- Masked Multi-Head Self-Attention: This prevents the decoder from attending to future tokens in the output sequence during training.
- Encoder-Decoder Attention: This layer allows the decoder to focus on relevant parts of the encoder's output while generating the output sequence.
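One detail the list above leaves implicit is that every sub-layer (attention or feed-forward) is wrapped in a residual connection followed by layer normalization, the "Add & Norm" step. A minimal sketch of that pattern, with an illustrative stand-in sub-layer:

```python
import torch
import torch.nn as nn

# "Add & Norm": residual connection around a sub-layer, then LayerNorm.
# Names here are illustrative, not from a specific library.
def add_and_norm(x, sublayer, norm):
    return norm(x + sublayer(x))

d_model = 8
norm = nn.LayerNorm(d_model)
sublayer = nn.Linear(d_model, d_model)  # stand-in for attention or FFN
x = torch.randn(2, 5, d_model)          # (batch, seq_len, d_model)
y = add_and_norm(x, sublayer, norm)
print(y.shape)  # torch.Size([2, 5, 8]) -- shape is preserved
```

Because each sub-layer preserves the (batch, seq_len, d_model) shape, these blocks can be stacked to arbitrary depth.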
Complete Transformer Implementation in PyTorch
Here's a full implementation of the transformer architecture with detailed comments on matrix operations:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class PositionalEncoding(nn.Module):
"""
Adds positional information to the input embeddings.
Uses sine and cosine functions of different frequencies.
"""
def __init__(self, d_model, max_len=5000):
super().__init__()
# Create positional encoding matrix
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
# Compute the div term for sine and cosine
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
# Apply sine to even indices
pe[:, 0::2] = torch.sin(position * div_term)
# Apply cosine to odd indices
pe[:, 1::2] = torch.cos(position * div_term)
# Add batch dimension: (1, max_len, d_model)
pe = pe.unsqueeze(0)
# Register as buffer (not a parameter, but should be saved with model)
self.register_buffer('pe', pe)
def forward(self, x):
"""
Args:
x: Tensor of shape (batch_size, seq_len, d_model)
Returns:
x + positional encoding
"""
# Add positional encoding to input
# x: (batch, seq_len, d_model)
# self.pe[:, :x.size(1)]: (1, seq_len, d_model)
return x + self.pe[:, :x.size(1)]
class FeedForward(nn.Module):
"""
Position-wise Feed-Forward Network.
Consists of two linear transformations with ReLU activation.
$$\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$$
"""
def __init__(self, d_model, d_ff, dropout=0.1):
super().__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
"""
Args:
x: (batch_size, seq_len, d_model)
Returns:
output: (batch_size, seq_len, d_model)
"""
# x: (batch, seq_len, d_model)
# After linear1: (batch, seq_len, d_ff)
# After ReLU: (batch, seq_len, d_ff)
# After linear2: (batch, seq_len, d_model)
return self.linear2(self.dropout(F.relu(self.linear1(x))))
class MultiHeadAttentionLayer(nn.Module):
"""
Multi-head attention layer with proper matrix dimension tracking.
"""
def __init__(self, d_model, num_heads, dropout=0.1):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# Linear layers for Q, K, V projections
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
self.scale = math.sqrt(self.d_k)
def forward(self, query, key, value, mask=None):
"""
Args:
query: (batch_size, seq_len_q, d_model)
key: (batch_size, seq_len_k, d_model)
value: (batch_size, seq_len_v, d_model)
mask: (batch_size, 1, seq_len_q, seq_len_k) or similar
"""
batch_size = query.size(0)
# Linear projections: (batch, seq_len, d_model) → (batch, seq_len, d_model)
Q = self.W_q(query)
K = self.W_k(key)
V = self.W_v(value)
# Reshape for multi-head attention
# (batch, seq_len, d_model) → (batch, seq_len, num_heads, d_k)
Q = Q.view(batch_size, -1, self.num_heads, self.d_k)
K = K.view(batch_size, -1, self.num_heads, self.d_k)
V = V.view(batch_size, -1, self.num_heads, self.d_k)
# Transpose to (batch, num_heads, seq_len, d_k)
Q = Q.transpose(1, 2)
K = K.transpose(1, 2)
V = V.transpose(1, 2)
# Scaled dot-product attention
# Q @ K^T: (batch, num_heads, seq_len_q, d_k) @ (batch, num_heads, d_k, seq_len_k)
# → (batch, num_heads, seq_len_q, seq_len_k)
scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
# Apply mask if provided
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Softmax over the last dimension
attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
# Apply attention to values
# attn_weights @ V: (batch, num_heads, seq_len_q, seq_len_k) @ (batch, num_heads, seq_len_v, d_k)
# → (batch, num_heads, seq_len_q, d_k)
output = torch.matmul(attn_weights, V)
# Concatenate heads
# Transpose: (batch, num_heads, seq_len_q, d_k) → (batch, seq_len_q, num_heads, d_k)
output = output.transpose(1, 2).contiguous()
# Reshape: (batch, seq_len_q, num_heads, d_k) → (batch, seq_len_q, d_model)
output = output.view(batch_size, -1, self.d_model)
# Final linear projection
output = self.W_o(output)
return output, attn_weights
class EncoderLayer(nn.Module):
"""
Single encoder layer consisting of:
1. Multi-head self-attention
2. Add & Norm (residual connection + layer normalization)
3. Feed-forward network
4. Add & Norm
"""
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadAttentionLayer(d_model, num_heads, dropout)
self.feed_forward = FeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(self, x, mask=None):
"""
Args:
x: (batch_size, seq_len, d_model)
mask: Attention mask
Returns:
output: (batch_size, seq_len, d_model)
"""
# Self-attention with residual connection
attn_output, _ = self.self_attn(x, x, x, mask)
x = x + self.dropout1(attn_output)
x = self.norm1(x)
# Feed-forward with residual connection
ff_output = self.feed_forward(x)
x = x + self.dropout2(ff_output)
x = self.norm2(x)
return x
class DecoderLayer(nn.Module):
"""
Single decoder layer consisting of:
1. Masked multi-head self-attention
2. Add & Norm
3. Multi-head cross-attention (attending to encoder output)
4. Add & Norm
5. Feed-forward network
6. Add & Norm
"""
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadAttentionLayer(d_model, num_heads, dropout)
self.cross_attn = MultiHeadAttentionLayer(d_model, num_heads, dropout)
self.feed_forward = FeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
"""
Args:
x: Decoder input (batch_size, tgt_seq_len, d_model)
encoder_output: Encoder output (batch_size, src_seq_len, d_model)
src_mask: Source mask for encoder-decoder attention
tgt_mask: Target mask for masked self-attention
"""
# Masked self-attention
attn_output, _ = self.self_attn(x, x, x, tgt_mask)
x = x + self.dropout1(attn_output)
x = self.norm1(x)
# Cross-attention to encoder output
# Query from decoder, Key and Value from encoder
attn_output, _ = self.cross_attn(x, encoder_output, encoder_output, src_mask)
x = x + self.dropout2(attn_output)
x = self.norm2(x)
# Feed-forward
ff_output = self.feed_forward(x)
x = x + self.dropout3(ff_output)
x = self.norm3(x)
return x
class Transformer(nn.Module):
"""
Complete Transformer model for sequence-to-sequence tasks.
"""
def __init__(
self,
src_vocab_size,
tgt_vocab_size,
d_model=512,
num_heads=8,
num_encoder_layers=6,
num_decoder_layers=6,
d_ff=2048,
dropout=0.1,
max_len=5000
):
super().__init__()
# Embeddings
self.src_embedding = nn.Embedding(src_vocab_size, d_model)
self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
self.positional_encoding = PositionalEncoding(d_model, max_len)
# Encoder
self.encoder_layers = nn.ModuleList([
EncoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_encoder_layers)
])
# Decoder
self.decoder_layers = nn.ModuleList([
DecoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_decoder_layers)
])
# Output projection
self.output_projection = nn.Linear(d_model, tgt_vocab_size)
self.dropout = nn.Dropout(dropout)
self.d_model = d_model
# Initialize weights
self._init_weights()
def _init_weights(self):
"""Initialize weights using Xavier initialization."""
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def make_src_mask(self, src):
"""
Create mask for source sequence (padding mask).
Args:
src: (batch_size, src_seq_len)
Returns:
mask: (batch_size, 1, 1, src_seq_len)
"""
src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
return src_mask
def make_tgt_mask(self, tgt):
"""
Create mask for target sequence (padding + causal mask).
Args:
tgt: (batch_size, tgt_seq_len)
Returns:
mask: (batch_size, 1, tgt_seq_len, tgt_seq_len)
"""
tgt_seq_len = tgt.size(1)
# Padding mask
tgt_padding_mask = (tgt != 0).unsqueeze(1).unsqueeze(2)
# Causal mask (prevent attending to future tokens)
tgt_sub_mask = torch.tril(
torch.ones((tgt_seq_len, tgt_seq_len), device=tgt.device)
).bool()
# Combine both masks
tgt_mask = tgt_padding_mask & tgt_sub_mask
return tgt_mask
def encode(self, src, src_mask):
"""
Encode source sequence.
Args:
src: (batch_size, src_seq_len)
src_mask: (batch_size, 1, 1, src_seq_len)
Returns:
encoder_output: (batch_size, src_seq_len, d_model)
"""
# Embedding + Positional encoding
# src: (batch, src_seq_len) → (batch, src_seq_len, d_model)
x = self.src_embedding(src) * math.sqrt(self.d_model)
x = self.positional_encoding(x)
x = self.dropout(x)
# Pass through encoder layers
for layer in self.encoder_layers:
x = layer(x, src_mask)
return x
def decode(self, tgt, encoder_output, src_mask, tgt_mask):
"""
Decode target sequence.
Args:
tgt: (batch_size, tgt_seq_len)
encoder_output: (batch_size, src_seq_len, d_model)
src_mask: (batch_size, 1, 1, src_seq_len)
tgt_mask: (batch_size, 1, tgt_seq_len, tgt_seq_len)
Returns:
decoder_output: (batch_size, tgt_seq_len, d_model)
"""
# Embedding + Positional encoding
x = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
x = self.positional_encoding(x)
x = self.dropout(x)
# Pass through decoder layers
for layer in self.decoder_layers:
x = layer(x, encoder_output, src_mask, tgt_mask)
return x
def forward(self, src, tgt):
"""
Forward pass through the entire transformer.
Args:
src: Source sequence (batch_size, src_seq_len)
tgt: Target sequence (batch_size, tgt_seq_len)
Returns:
output: Logits (batch_size, tgt_seq_len, tgt_vocab_size)
"""
# Create masks
src_mask = self.make_src_mask(src)
tgt_mask = self.make_tgt_mask(tgt)
# Encode
encoder_output = self.encode(src, src_mask)
# Decode
decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)
# Project to vocabulary
output = self.output_projection(decoder_output)
return output
# Example usage
if __name__ == "__main__":
# Model hyperparameters
src_vocab_size = 10000
tgt_vocab_size = 10000
d_model = 512
num_heads = 8
num_encoder_layers = 6
num_decoder_layers = 6
d_ff = 2048
dropout = 0.1
# Create model
model = Transformer(
src_vocab_size=src_vocab_size,
tgt_vocab_size=tgt_vocab_size,
d_model=d_model,
num_heads=num_heads,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
d_ff=d_ff,
dropout=dropout
)
# Example input (batch_size=2, sequences of length 10)
src = torch.randint(1, src_vocab_size, (2, 10))
tgt = torch.randint(1, tgt_vocab_size, (2, 12))
print("="*60)
print("TRANSFORMER MODEL SUMMARY")
print("="*60)
print(f"Source sequence shape: {src.shape}")
print(f"Target sequence shape: {tgt.shape}")
print(f"\nModel parameters:")
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
# Forward pass
output = model(src, tgt)
print(f"\nOutput shape: {output.shape}")
print(f"Expected: (batch_size={src.size(0)}, tgt_seq_len={tgt.size(1)}, tgt_vocab_size={tgt_vocab_size})")
# Show dimension flow through the model
print("\n" + "="*60)
print("DIMENSION FLOW THROUGH TRANSFORMER")
print("="*60)
print("\nENCODER:")
print(f"1. Input tokens: {src.shape}")
print(f"2. After embedding: (batch={src.size(0)}, seq={src.size(1)}, d_model={d_model})")
print(f"3. After positional encoding: Same shape")
print(f"4. Through {num_encoder_layers} encoder layers: Same shape")
print(f"5. Encoder output: (batch={src.size(0)}, seq={src.size(1)}, d_model={d_model})")
print("\nDECODER:")
print(f"1. Input tokens: {tgt.shape}")
print(f"2. After embedding: (batch={tgt.size(0)}, seq={tgt.size(1)}, d_model={d_model})")
print(f"3. After positional encoding: Same shape")
print(f"4. Through {num_decoder_layers} decoder layers: Same shape")
print(f"5. After output projection: (batch={tgt.size(0)}, seq={tgt.size(1)}, vocab={tgt_vocab_size})")
print("\n" + "="*60)
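As a sanity check on the printed parameter summary, the count can be reproduced by hand from the layer shapes above. This is pure arithmetic, assuming the example hyperparameters (src_vocab_size = tgt_vocab_size = 10000, d_model = 512, d_ff = 2048, 6 encoder and 6 decoder layers):

```python
d_model, d_ff, vocab = 512, 2048, 10000

# One attention block: W_q, W_k, W_v, W_o, each d_model x d_model plus bias
attn = 4 * (d_model * d_model + d_model)                     # 1,050,624
# Feed-forward: two linear layers with biases
ffn = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)   # 2,099,712
norm = 2 * d_model                                           # weight + bias per LayerNorm

enc_layer = attn + ffn + 2 * norm            # self-attn + FFN, 2 norms
dec_layer = 2 * attn + ffn + 3 * norm        # self-attn + cross-attn + FFN, 3 norms

total = (
    2 * vocab * d_model            # source + target embeddings
    + 6 * enc_layer + 6 * dec_layer
    + d_model * vocab + vocab      # output projection
)
print(f"{total:,}")  # 59,508,496 -- should match the model summary above
```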
Training the Transformer
Here's how you would train this transformer for a translation task:
import torch.optim as optim
# Initialize model, loss, and optimizer
model = Transformer(
src_vocab_size=10000,
tgt_vocab_size=10000,
d_model=512,
num_heads=8,
num_encoder_layers=6,
num_decoder_layers=6,
d_ff=2048,
dropout=0.1
)
criterion = nn.CrossEntropyLoss(ignore_index=0) # Ignore padding token
optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
# Training loop
def train_step(model, src, tgt, optimizer, criterion):
"""
Single training step.
Args:
src: Source sequences (batch_size, src_seq_len)
tgt: Target sequences (batch_size, tgt_seq_len)
"""
model.train()
optimizer.zero_grad()
# Forward pass
# Input to decoder is target shifted right (teacher forcing)
tgt_input = tgt[:, :-1] # Remove last token
tgt_output = tgt[:, 1:] # Remove first token (usually <sos>)
# Get model predictions
# output: (batch_size, tgt_seq_len-1, vocab_size)
output = model(src, tgt_input)
# Reshape for loss calculation
# output: (batch_size * (tgt_seq_len-1), vocab_size)
# tgt_output: (batch_size * (tgt_seq_len-1))
output = output.reshape(-1, output.size(-1))
tgt_output = tgt_output.reshape(-1)
# Calculate loss
loss = criterion(output, tgt_output)
# Backward pass
loss.backward()
# Gradient clipping (prevents exploding gradients)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# Update weights
optimizer.step()
return loss.item()
# Example training
for epoch in range(10):
# Generate dummy batch
src = torch.randint(1, 10000, (32, 20)) # batch_size=32, seq_len=20
tgt = torch.randint(1, 10000, (32, 25)) # batch_size=32, seq_len=25
loss = train_step(model, src, tgt, optimizer, criterion)
print(f"Epoch {epoch+1}, Loss: {loss:.4f}")
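The Adam settings above (betas=(0.9, 0.98), eps=1e-9) come from the original "Attention Is All You Need" paper, which pairs them with a warmup-then-decay learning-rate schedule rather than a fixed rate. A minimal sketch of that schedule (warmup_steps=4000 is the paper's default):

```python
import math

def transformer_lr(step, d_model=512, warmup_steps=4000):
    # Schedule from "Attention Is All You Need": linear warmup for
    # warmup_steps, then inverse-square-root decay.
    step = max(step, 1)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

for s in [1, 1000, 4000, 8000, 100000]:
    print(f"step {s:>6}: lr = {transformer_lr(s):.6f}")
```

In practice this could be plugged into the optimizer via `torch.optim.lr_scheduler.LambdaLR`; the function above just makes the shape of the schedule explicit.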
Inference with the Transformer
def greedy_decode(model, src, max_len, start_token, end_token):
"""
Greedy decoding: always select the most likely next token.
Args:
model: Trained transformer model
src: Source sequence (1, src_seq_len)
max_len: Maximum length of generated sequence
start_token: Start token ID
end_token: End token ID
Returns:
Generated sequence
"""
model.eval()
# Encode the source
src_mask = model.make_src_mask(src)
encoder_output = model.encode(src, src_mask)
# Initialize decoder input with start token
tgt = torch.tensor([[start_token]], device=src.device)
for _ in range(max_len):
# Create target mask
tgt_mask = model.make_tgt_mask(tgt)
# Decode
decoder_output = model.decode(tgt, encoder_output, src_mask, tgt_mask)
# Get predictions for the last token
# decoder_output: (1, current_seq_len, d_model)
# We only need the last position: (1, 1, d_model)
output = model.output_projection(decoder_output[:, -1:, :])
# Get the token with highest probability
# output: (1, 1, vocab_size) → (1, 1)
next_token = output.argmax(dim=-1)
# Append to target sequence
tgt = torch.cat([tgt, next_token], dim=1)
# Stop if we generate the end token
if next_token.item() == end_token:
break
return tgt
# Example inference
src_sequence = torch.randint(1, 10000, (1, 20))
generated = greedy_decode(
model=model,
src=src_sequence,
max_len=50,
start_token=1, # <sos> token
end_token=2 # <eos> token
)
print(f"Generated sequence: {generated}")
print(f"Generated sequence shape: {generated.shape}")
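Greedy decoding always picks the argmax, which can produce repetitive output. A common alternative is to sample from the k most likely tokens at each step. The sketch below operates on a bare logits vector rather than the model, so it is self-contained; the vocabulary size of 100 is illustrative:

```python
import torch

def top_k_sample(logits, k=5, temperature=1.0):
    # Keep only the k highest-scoring tokens, renormalize with softmax,
    # then sample one of them. temperature > 1 flattens the distribution.
    topk_vals, topk_idx = torch.topk(logits / temperature, k)
    probs = torch.softmax(topk_vals, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)
    return topk_idx[choice]

torch.manual_seed(0)
logits = torch.randn(100)  # pretend final-position logits over a 100-token vocab
token = top_k_sample(logits, k=5)
print(token.item())  # always one of the 5 most likely tokens
```

Inside `greedy_decode`, this would replace the `output.argmax(dim=-1)` line, applied to the last-position logits.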
Applications
Transformers have been successfully applied in various domains, including:
- Natural Language Processing: Models like BERT, GPT, and T5 are based on the transformer architecture and have achieved state-of-the-art results in numerous NLP tasks.
- Computer Vision: Vision Transformers (ViTs) have adapted the transformer architecture for image classification and other vision tasks, demonstrating performance competitive with traditional convolutional neural networks (CNNs).
- Speech Processing: Transformers are also being explored for speech recognition and synthesis, leveraging their ability to model sequential data.
Conclusion
Transformers have reshaped the landscape of machine learning, particularly in NLP, by providing a powerful and flexible framework for modeling complex relationships in data. Their ability to capture long-range dependencies and parallelize training has made them a go-to choice for many modern AI applications.
ELI10: What are Transformers?
Transformers are like super-smart assistants that help computers understand and generate human language. Imagine you have a friend who can read a whole book at once and remember everything about it. That's what transformers do! They look at all the words in a sentence and figure out how they relate to each other, which helps them answer questions, translate languages, or even write stories.
Example Usage
- Text Generation: Given a prompt, transformers can generate coherent and contextually relevant text.
- Translation: They can translate sentences from one language to another by understanding the meaning of the words in context.
- Summarization: Transformers can read long articles and provide concise summaries, capturing the main points effectively.
Hugging Face Transformers
Comprehensive guide to using the Hugging Face ecosystem for NLP and beyond.
Table of Contents
- Introduction
- Transformers Library
- Model Hub
- Datasets Library
- Tokenizers
- Training and Fine-tuning
- Inference and Deployment
Introduction
Hugging Face Ecosystem:
- Transformers: State-of-the-art NLP models
- Datasets: Easy access to datasets
- Tokenizers: Fast tokenization
- Accelerate: Distributed training
- Optimum: Hardware optimization
# Installation
pip install transformers datasets tokenizers accelerate
pip install torch torchvision torchaudio # PyTorch
# OR
pip install tensorflow # TensorFlow
Transformers Library
Basic Usage
from transformers import (
AutoTokenizer,
AutoModel,
AutoModelForSequenceClassification,
pipeline
)
# Quick start with pipelines
classifier = pipeline("sentiment-analysis")
result = classifier("I love using Hugging Face!")
print(result)
# [{'label': 'POSITIVE', 'score': 0.9998}]
# Multiple examples
results = classifier([
"I love this!",
"I hate this!",
"This is okay."
])
print(results)
Loading Models and Tokenizers
# Load pre-trained model and tokenizer
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
# Tokenize text
text = "Hello, how are you?"
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
# Get embeddings
with torch.no_grad():
outputs = model(**inputs)
last_hidden_states = outputs.last_hidden_state
pooler_output = outputs.pooler_output
print(f"Last hidden state shape: {last_hidden_states.shape}")
print(f"Pooler output shape: {pooler_output.shape}")
Common Model Types
# Sequence Classification (e.g., sentiment analysis)
model = AutoModelForSequenceClassification.from_pretrained(
"distilbert-base-uncased-finetuned-sst-2-english"
)
# Token Classification (e.g., NER)
from transformers import AutoModelForTokenClassification
model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
# Question Answering
from transformers import AutoModelForQuestionAnswering
model = AutoModelForQuestionAnswering.from_pretrained("distilbert-base-cased-distilled-squad")
# Text Generation
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("gpt2")
# Masked Language Modeling
from transformers import AutoModelForMaskedLM
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
# Sequence-to-Sequence (e.g., translation, summarization)
from transformers import AutoModelForSeq2SeqLM
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
Pipelines
# Sentiment Analysis
sentiment = pipeline("sentiment-analysis")
print(sentiment("This movie is great!"))
# Named Entity Recognition
ner = pipeline("ner", grouped_entities=True)
print(ner("My name is John and I live in New York"))
# Question Answering
qa = pipeline("question-answering")
context = "The Eiffel Tower is located in Paris, France."
question = "Where is the Eiffel Tower?"
print(qa(question=question, context=context))
# Text Generation
generator = pipeline("text-generation", model="gpt2")
print(generator("Once upon a time", max_length=50, num_return_sequences=2))
# Translation
translator = pipeline("translation_en_to_fr", model="t5-base")
print(translator("Hello, how are you?"))
# Summarization
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
article = """Long article text here..."""
print(summarizer(article, max_length=130, min_length=30))
# Zero-shot Classification
classifier = pipeline("zero-shot-classification")
text = "This is a course about Python programming"
labels = ["education", "politics", "business"]
print(classifier(text, candidate_labels=labels))
# Fill Mask
unmasker = pipeline("fill-mask", model="bert-base-uncased")
print(unmasker("The capital of France is [MASK]."))
# Feature Extraction
feature_extractor = pipeline("feature-extraction")
features = feature_extractor("Hello world!")
print(f"Number of token vectors: {len(features[0])}")
# Image Classification
from transformers import pipeline
image_classifier = pipeline("image-classification", model="google/vit-base-patch16-224")
result = image_classifier("path/to/image.jpg")
# Object Detection
object_detector = pipeline("object-detection", model="facebook/detr-resnet-50")
results = object_detector("path/to/image.jpg")
Custom Pipeline
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
import torch
class CustomSentimentPipeline:
def __init__(self, model_name="distilbert-base-uncased-finetuned-sst-2-english"):
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
self.model.eval()
def __call__(self, texts, batch_size=8):
if isinstance(texts, str):
texts = [texts]
results = []
for i in range(0, len(texts), batch_size):
batch = texts[i:i+batch_size]
# Tokenize
inputs = self.tokenizer(
batch,
return_tensors="pt",
padding=True,
truncation=True,
max_length=512
)
# Forward pass
with torch.no_grad():
outputs = self.model(**inputs)
logits = outputs.logits
probs = torch.softmax(logits, dim=-1)
# Convert to results
for j, prob in enumerate(probs):
label_id = prob.argmax().item()
score = prob[label_id].item()
label = self.model.config.id2label[label_id]
results.append({
'text': batch[j],
'label': label,
'score': score
})
return results
# Usage
custom_pipeline = CustomSentimentPipeline()
results = custom_pipeline(["I love this!", "I hate this!"])
print(results)
Model Hub
Searching and Filtering Models
from huggingface_hub import HfApi, list_models
api = HfApi()
# List models
models = list_models(
filter="text-classification",
sort="downloads",
direction=-1,
limit=10
)
for model in models:
print(f"{model.modelId}: {model.downloads} downloads")
# Search for specific models
models = list_models(search="bert", filter="fill-mask")
for model in models:
print(model.modelId)
Uploading Models
from huggingface_hub import HfApi, create_repo
# Create repository
api = HfApi()
repo_url = create_repo(
repo_id="username/model-name",
token="your_token_here",
private=False
)
# Upload model
api.upload_file(
path_or_fileobj="path/to/model.bin",
path_in_repo="model.bin",
repo_id="username/model-name",
token="your_token_here"
)
# Or use model.push_to_hub()
model.push_to_hub("username/model-name", token="your_token_here")
tokenizer.push_to_hub("username/model-name", token="your_token_here")
Datasets Library
Loading Datasets
from datasets import load_dataset, load_metric
# Load popular datasets
dataset = load_dataset("glue", "mrpc")
print(dataset)
# Load specific split
train_dataset = load_dataset("imdb", split="train")
test_dataset = load_dataset("imdb", split="test")
# Load subset
small_train = load_dataset("imdb", split="train[:1000]")
# Stream large datasets
dataset = load_dataset("c4", "en", split="train", streaming=True)
for example in dataset:
print(example)
break
# Load from CSV
dataset = load_dataset("csv", data_files="path/to/file.csv")
# Load from JSON
dataset = load_dataset("json", data_files="path/to/file.json")
# Load from multiple files
dataset = load_dataset(
"json",
data_files={
"train": "train.json",
"test": "test.json"
}
)
Dataset Operations
from datasets import Dataset, DatasetDict
# Create custom dataset
data = {
"text": ["Hello", "World", "!"],
"label": [0, 1, 0]
}
dataset = Dataset.from_dict(data)
# Map function
def tokenize_function(examples):
return tokenizer(examples["text"], padding="max_length", truncation=True)
tokenized_dataset = dataset.map(tokenize_function, batched=True)
# Filter
filtered_dataset = dataset.filter(lambda x: x["label"] == 1)
# Select
small_dataset = dataset.select(range(100))
# Shuffle
shuffled = dataset.shuffle(seed=42)
# Split
split_dataset = dataset.train_test_split(test_size=0.2)
# Sort
sorted_dataset = dataset.sort("label")
# Add column
dataset = dataset.map(lambda x: {"length": len(x["text"])})
# Remove columns
dataset = dataset.remove_columns(["length"])
# Save and load
dataset.save_to_disk("path/to/save")
loaded_dataset = Dataset.load_from_disk("path/to/save")
Data Collators
from transformers import DataCollatorWithPadding, DataCollatorForLanguageModeling
# Dynamic padding
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# For MLM (masked language modeling)
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=True,
mlm_probability=0.15
)
# For sequence-to-sequence
from transformers import DataCollatorForSeq2Seq
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
# Custom data collator
from dataclasses import dataclass
from typing import Any, Dict, List
import torch
@dataclass
class CustomDataCollator:
tokenizer: AutoTokenizer
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
batch = {}
# Extract and pad text
texts = [f["text"] for f in features]
tokenized = self.tokenizer(
texts,
padding=True,
truncation=True,
return_tensors="pt"
)
batch.update(tokenized)
# Add labels
if "label" in features[0]:
batch["labels"] = torch.tensor([f["label"] for f in features])
return batch
Tokenizers
Using Tokenizers
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Basic tokenization
text = "Hello, how are you?"
tokens = tokenizer.tokenize(text)
print(f"Tokens: {tokens}")
# Encode (text to IDs)
input_ids = tokenizer.encode(text)
print(f"Input IDs: {input_ids}")
# Decode (IDs to text)
decoded = tokenizer.decode(input_ids)
print(f"Decoded: {decoded}")
# Full tokenization with special tokens
encoded = tokenizer(
text,
padding="max_length",
truncation=True,
max_length=128,
return_tensors="pt"
)
print(f"Input IDs shape: {encoded['input_ids'].shape}")
print(f"Attention mask shape: {encoded['attention_mask'].shape}")
# Batch tokenization
texts = ["Hello!", "How are you?", "Nice to meet you."]
batch_encoded = tokenizer(
texts,
padding=True,
truncation=True,
return_tensors="pt"
)
# Token type IDs (for sentence pairs)
text_a = "This is sentence A"
text_b = "This is sentence B"
encoded = tokenizer(text_a, text_b, return_tensors="pt")
print(f"Token type IDs: {encoded['token_type_ids']}")
Fast Tokenizers
from tokenizers import Tokenizer
from tokenizers.models import BPE, WordPiece
from tokenizers.trainers import BpeTrainer, WordPieceTrainer
from tokenizers.pre_tokenizers import Whitespace
# Create BPE tokenizer
tokenizer = Tokenizer(BPE())
tokenizer.pre_tokenizer = Whitespace()
# Train tokenizer
trainer = BpeTrainer(vocab_size=30000, special_tokens=["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"])
files = ["path/to/text1.txt", "path/to/text2.txt"]
tokenizer.train(files, trainer)
# Save tokenizer
tokenizer.save("path/to/tokenizer.json")
# Load tokenizer
from transformers import PreTrainedTokenizerFast
fast_tokenizer = PreTrainedTokenizerFast(tokenizer_file="path/to/tokenizer.json")
Training and Fine-tuning
Using Trainer API
from transformers import Trainer, TrainingArguments, AutoModelForSequenceClassification
from datasets import load_dataset
import evaluate
import numpy as np

# Load dataset and model
dataset = load_dataset("glue", "mrpc")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Tokenize dataset
def tokenize_function(examples):
    return tokenizer(examples["sentence1"], examples["sentence2"], padding="max_length", truncation=True)

tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Define metrics (load_metric was removed from datasets; use the evaluate library)
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
# Training arguments
training_args = TrainingArguments(
output_dir="./results",
evaluation_strategy="epoch",
save_strategy="epoch",
learning_rate=2e-5,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
num_train_epochs=3,
weight_decay=0.01,
load_best_model_at_end=True,
metric_for_best_model="accuracy",
logging_dir="./logs",
logging_steps=100,
save_total_limit=2,
fp16=True, # Mixed precision training
dataloader_num_workers=4,
gradient_accumulation_steps=2,
warmup_steps=500,
)
# Create trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_datasets["train"],
eval_dataset=tokenized_datasets["validation"],
compute_metrics=compute_metrics,
)
# Train
trainer.train()
# Evaluate
results = trainer.evaluate()
print(results)
# Predict
predictions = trainer.predict(tokenized_datasets["test"])
print(predictions.metrics)
# Save model
trainer.save_model("./final_model")
Custom Training Loop
from torch.utils.data import DataLoader
from torch.optim import AdamW  # transformers.AdamW is deprecated/removed; use torch's
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
def train_custom(model, train_dataset, eval_dataset, num_epochs=3):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Data loaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=16)
# Optimizer and scheduler
optimizer = AdamW(model.parameters(), lr=2e-5)
total_steps = len(train_loader) * num_epochs
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=int(0.1 * total_steps),  # must be an integer step count
num_training_steps=total_steps
)
# Training loop
for epoch in range(num_epochs):
model.train()
total_loss = 0
progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
for batch in progress_bar:
# Move to device
batch = {k: v.to(device) for k, v in batch.items()}
# Forward pass
outputs = model(**batch)
loss = outputs.loss
# Backward pass
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
scheduler.step()
total_loss += loss.item()
progress_bar.set_postfix({"loss": loss.item()})
# Evaluation
model.eval()
eval_loss = 0
predictions, true_labels = [], []
with torch.no_grad():
for batch in eval_loader:
batch = {k: v.to(device) for k, v in batch.items()}
outputs = model(**batch)
eval_loss += outputs.loss.item()
logits = outputs.logits
preds = torch.argmax(logits, dim=-1)
predictions.extend(preds.cpu().numpy())
true_labels.extend(batch["labels"].cpu().numpy())
# Compute metrics
accuracy = np.mean(np.array(predictions) == np.array(true_labels))
print(f"Epoch {epoch+1}:")
print(f" Train Loss: {total_loss/len(train_loader):.4f}")
print(f" Eval Loss: {eval_loss/len(eval_loader):.4f}")
print(f" Accuracy: {accuracy:.4f}")
return model
Inference and Deployment
Optimized Inference
# Model optimization
from transformers import pipeline
from optimum.onnxruntime import ORTModelForSequenceClassification, ORTOptimizer
from optimum.onnxruntime.configuration import OptimizationConfig
# Convert to ONNX
model = ORTModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english",
    export=True  # was from_transformers=True in older optimum releases
)
# Optimize
optimizer = ORTOptimizer.from_pretrained(model)
optimization_config = OptimizationConfig(optimization_level=2)
optimizer.optimize(save_dir="optimized_model", optimization_config=optimization_config)
# Use the optimized model (reload from the save_dir above)
optimized_model = ORTModelForSequenceClassification.from_pretrained("optimized_model")
optimized_pipeline = pipeline(
    "sentiment-analysis",
    model=optimized_model,
    tokenizer=tokenizer
)
# Quantization
from optimum.onnxruntime import ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig
quantizer = ORTQuantizer.from_pretrained(model)
qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False)
quantizer.quantize(save_dir="quantized_model", quantization_config=qconfig)
Batch Inference
def batch_predict(texts, model, tokenizer, batch_size=32):
"""Efficient batch prediction"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
all_predictions = []
for i in range(0, len(texts), batch_size):
batch_texts = texts[i:i+batch_size]
# Tokenize
inputs = tokenizer(
batch_texts,
padding=True,
truncation=True,
max_length=512,
return_tensors="pt"
)
inputs = {k: v.to(device) for k, v in inputs.items()}
# Predict
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
predictions = torch.argmax(logits, dim=-1)
all_predictions.extend(predictions.cpu().numpy())
return all_predictions
# Usage
texts = ["text1", "text2", "text3"] * 1000
predictions = batch_predict(texts, model, tokenizer)
API Deployment
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline
app = FastAPI()
# Load model once
classifier = pipeline("sentiment-analysis")
class TextRequest(BaseModel):
text: str
class PredictionResponse(BaseModel):
label: str
score: float
@app.post("/predict", response_model=PredictionResponse)
def predict(request: TextRequest):
result = classifier(request.text)[0]
return PredictionResponse(label=result['label'], score=result['score'])
# Run with: uvicorn app:app --host 0.0.0.0 --port 8000
Practical Tips
- Model Selection: Choose based on task, speed, and accuracy requirements
- Tokenization: Handle special characters and multiple languages carefully
- Batch Size: Adjust based on GPU memory
- Mixed Precision: Use fp16 for faster training
- Gradient Accumulation: Simulate larger batch sizes
- Model Evaluation: Use appropriate metrics for your task
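The batch-size and gradient-accumulation tips interact: the optimizer effectively steps with the product of the per-device batch size, the accumulation steps, and the device count. A minimal sketch (the helper name is ours, not a transformers API):

```python
def effective_batch_size(per_device_batch: int,
                         accumulation_steps: int,
                         num_devices: int = 1) -> int:
    """Batch size the optimizer effectively updates with."""
    return per_device_batch * accumulation_steps * num_devices

# e.g. per_device_train_batch_size=16, gradient_accumulation_steps=2 on one GPU
print(effective_batch_size(16, 2))  # -> 32
```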
Resources
- Hugging Face Documentation: https://huggingface.co/docs
- Course: https://huggingface.co/course
- Model Hub: https://huggingface.co/models
- Datasets Hub: https://huggingface.co/datasets
- Forums: https://discuss.huggingface.co/
PyTorch
Overview
PyTorch is a deep learning framework developed by Meta (Facebook) that provides:
- Dynamic computation graphs: Build networks on-the-fly (unlike the static graphs of TensorFlow 1.x)
- Pythonic API: Natural, intuitive syntax for building neural networks
- GPU acceleration: Seamless CUDA support for fast training
- Rich ecosystem: Tools for NLP, computer vision, reinforcement learning
- Production ready: Deploy with TorchScript, ONNX, or mobile
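To make the dynamic-graph point concrete: the forward pass is plain Python, so control flow can depend on the data itself and the graph is rebuilt on every call (a toy sketch, not a useful model):

```python
import torch
import torch.nn as nn

class DynamicNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        # Ordinary Python control flow: the graph is rebuilt on every call
        if x.sum() > 0:
            return self.fc(x)
        return self.fc(-x)

net = DynamicNet()
out = net(torch.ones(1, 4))
print(out.shape)  # torch.Size([1, 4])
```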
Installation
# CPU only
pip install torch torchvision torchaudio
# GPU (CUDA 11.8)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# GPU (CUDA 12.1)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# Check installation
python -c "import torch; print(torch.__version__); print(torch.cuda.is_available())"
Core Concepts
Tensors
Tensors are the fundamental building blocks - N-dimensional arrays:
import torch
# Creating tensors
t1 = torch.tensor([1, 2, 3]) # From list
t2 = torch.zeros(3, 4) # Zeros tensor
t3 = torch.ones(2, 3) # Ones tensor
t4 = torch.randn(3, 4) # Random normal distribution
t5 = torch.arange(0, 10, 2) # Range: [0, 2, 4, 6, 8]
# Tensor properties
print(t1.shape) # torch.Size([3])
print(t1.dtype) # torch.int64
print(t1.device) # cpu
# Move to GPU
if torch.cuda.is_available():
t1 = t1.cuda() # or t1.to('cuda')
print(t1.device) # cuda:0
# Tensor operations
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])
c = a + b # Element-wise addition
d = a * b # Element-wise multiplication
e = torch.dot(a, b) # Dot product: 32.0
f = torch.matmul(a.view(3, 1), b.view(1, 3)) # Matrix multiplication
# Reshaping
x = torch.randn(2, 3, 4)
y = x.view(6, 4) # Reshape to (6, 4)
z = x.reshape(-1) # Flatten (auto-infer dimension)
Autograd (Automatic Differentiation)
PyTorch computes gradients automatically:
import torch
# Enable gradient tracking
x = torch.tensor([2.0, 3.0], requires_grad=True)
y = torch.tensor([1.0, 2.0], requires_grad=True)
# Forward pass
z = x.pow(2).sum() + (y * x).sum() # z = x^2 + y*x
# Backward pass (compute gradients)
z.backward()
print(x.grad) # dz/dx
print(y.grad) # dz/dy
# Example: dz/dx = 2*x + y = [5, 8] for x=[2,3], y=[1,2]
Neural Network Building
import torch
import torch.nn as nn
import torch.nn.functional as F
# Define a simple network
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc1 = nn.Linear(784, 128) # Input: 28*28=784, Output: 128
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 10) # 10 output classes
def forward(self, x):
x = x.view(x.size(0), -1) # Flatten: (batch, 784)
x = F.relu(self.fc1(x)) # ReLU activation
x = F.relu(self.fc2(x))
x = self.fc3(x) # No activation (raw logits)
return x
# Create model and move to device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleNet().to(device)
# Check model architecture
print(model)
# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params}")
Datasets and DataLoaders
Custom Dataset
Create custom datasets by inheriting from torch.utils.data.Dataset:
from torch.utils.data import Dataset, DataLoader
import torch
class CustomDataset(Dataset):
def __init__(self, data, labels, transform=None):
"""
Args:
data: List or array of inputs
labels: List or array of labels
transform: Optional transformations to apply
"""
self.data = data
self.labels = labels
self.transform = transform
def __len__(self):
"""Return total number of samples"""
return len(self.data)
def __getitem__(self, idx):
"""Return sample at index idx"""
sample = self.data[idx]
label = self.labels[idx]
if self.transform:
sample = self.transform(sample)
return sample, label
# Usage
X = torch.randn(1000, 28, 28) # 1000 images of 28x28
y = torch.randint(0, 10, (1000,)) # 1000 labels (10 classes)
dataset = CustomDataset(X, y)
print(f"Dataset size: {len(dataset)}")
sample, label = dataset[0]
print(f"Sample shape: {sample.shape}, Label: {label}")
Image Dataset with Transforms
from torchvision import transforms
from PIL import Image
class ImageDataset(Dataset):
def __init__(self, image_paths, labels, transform=None):
self.image_paths = image_paths
self.labels = labels
self.transform = transform
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
# Load image
image = Image.open(self.image_paths[idx]).convert('RGB')
# Apply transforms
if self.transform:
image = self.transform(image)
label = self.labels[idx]
return image, label
# Define transforms
train_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomRotation(10),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
test_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
# Create datasets
train_dataset = ImageDataset(train_paths, train_labels, transform=train_transform)
test_dataset = ImageDataset(test_paths, test_labels, transform=test_transform)
Built-in Datasets
PyTorch provides common datasets in torchvision.datasets:
from torchvision import datasets, transforms
# MNIST
mnist_train = datasets.MNIST(
root='./data',
train=True,
download=True,
transform=transforms.ToTensor()
)
mnist_test = datasets.MNIST(
root='./data',
train=False,
download=True,
transform=transforms.ToTensor()
)
# CIFAR-10
cifar10 = datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=transforms.ToTensor()
)
# ImageNet (large, requires manual download)
imagenet = datasets.ImageNet(
root='./data',
split='train',
transform=transforms.ToTensor()
)
# Print dataset info
print(f"Dataset size: {len(mnist_train)}")
sample, label = mnist_train[0]
print(f"Sample shape: {sample.shape}, Label: {label}")
DataLoader
DataLoader handles batching, shuffling, and parallel loading:
from torch.utils.data import DataLoader
# Create DataLoader
train_loader = DataLoader(
dataset=train_dataset,
batch_size=32, # Samples per batch
shuffle=True, # Shuffle order every epoch
num_workers=4, # Parallel workers for data loading
pin_memory=True, # Pin memory for faster GPU transfer
drop_last=True # Drop last incomplete batch
)
test_loader = DataLoader(
dataset=test_dataset,
batch_size=32,
shuffle=False, # Don't shuffle test data
num_workers=4,
pin_memory=True,
drop_last=False
)
# Iterate through batches
for batch_idx, (batch_x, batch_y) in enumerate(train_loader):
print(f"Batch {batch_idx}")
print(f" Input shape: {batch_x.shape}") # (32, 1, 28, 28)
print(f" Labels shape: {batch_y.shape}") # (32,)
if batch_idx == 0:
break
Data Splits
from torch.utils.data import random_split
# Original dataset
dataset = CustomDataset(X, y)
# Split into train (70%), val (15%), test (15%)
total_size = len(dataset)
train_size = int(0.7 * total_size)
val_size = int(0.15 * total_size)
test_size = total_size - train_size - val_size
train_set, val_set, test_set = random_split(
dataset,
[train_size, val_size, test_size]
)
# Create loaders
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32, shuffle=False)
test_loader = DataLoader(test_set, batch_size=32, shuffle=False)
Data Augmentation Strategies
from torchvision import transforms
# For images
augmentation = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomVerticalFlip(p=0.2),
transforms.RandomRotation(15),
transforms.ColorJitter(brightness=0.2, contrast=0.2),
transforms.GaussianBlur(kernel_size=3),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5], std=[0.5])
])
# For text (custom)
class TextAugmentation:
def __init__(self, vocab_size=10000):
self.vocab_size = vocab_size
def __call__(self, tokens):
# Random dropout of tokens
if torch.rand(1) > 0.5:
mask = torch.rand(len(tokens)) > 0.1
tokens = tokens[mask]
return tokens
# Custom augmentation
class MixupAugmentation:
def __init__(self, alpha=1.0):
self.alpha = alpha
def __call__(self, batch_x, batch_y):
"""Mixup data augmentation"""
lam = torch.distributions.Beta(self.alpha, self.alpha).sample()
batch_size = batch_x.size(0)
index = torch.randperm(batch_size)
mixed_x = lam * batch_x + (1 - lam) * batch_x[index]
mixed_y = lam * batch_y.float() + (1 - lam) * batch_y[index].float()
return mixed_x, mixed_y
DataLoader Performance Tips
# Good configuration
loader = DataLoader(
dataset,
batch_size=64, # Larger batches for efficiency
shuffle=True,
num_workers=4, # Use multiple workers (2-4 per GPU)
pin_memory=True, # Pin to CPU memory for GPU transfer
persistent_workers=True, # Keep workers alive between epochs
prefetch_factor=2 # Prefetch batches (2-4 recommended)
)
# Monitor data loading performance
import time
start = time.time()
for batch in loader:
pass
elapsed = time.time() - start
print(f"Time to load {len(loader)} batches: {elapsed:.2f}s")
# If loading is slow:
# - Increase num_workers
# - Check disk speed (SSD vs HDD)
# - Use pin_memory=True
# - Reduce image resolution if possible
# - Use data compression
Combining Datasets
from torch.utils.data import ConcatDataset, Subset
# Concatenate multiple datasets
combined_dataset = ConcatDataset([dataset1, dataset2, dataset3])
# Subset of dataset
indices = list(range(0, 100)) # First 100 samples
subset = Subset(dataset, indices)
# Weighted sampling (e.g., for imbalanced data)
from torch.utils.data import WeightedRandomSampler
weights = [1.0 if label == 0 else 10.0 for label in dataset.labels]
sampler = WeightedRandomSampler(weights, len(dataset), replacement=True)
loader = DataLoader(
dataset,
batch_size=32,
sampler=sampler # Use sampler instead of shuffle
)
Training Loop
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
# Dummy data
X_train = torch.randn(1000, 784)
y_train = torch.randint(0, 10, (1000,))
# Create dataloader
dataset = TensorDataset(X_train, y_train)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
# Model, loss, optimizer
model = SimpleNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Training loop
num_epochs = 5
for epoch in range(num_epochs):
total_loss = 0
for batch_x, batch_y in dataloader:
batch_x, batch_y = batch_x.to(device), batch_y.to(device)
# Forward pass
logits = model(batch_x)
loss = criterion(logits, batch_y)
# Backward pass
optimizer.zero_grad() # Clear old gradients
loss.backward() # Compute new gradients
optimizer.step() # Update parameters
total_loss += loss.item()
avg_loss = total_loss / len(dataloader)
print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")
Convolutional Neural Networks
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
# Input: (batch, 3, 32, 32) - 3 channels, 32x32 images
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(64 * 8 * 8, 128)
self.fc2 = nn.Linear(128, 10)
self.dropout = nn.Dropout(0.5)
def forward(self, x):
# Conv block 1
x = self.conv1(x) # (batch, 32, 32, 32)
x = F.relu(x)
x = self.pool(x) # (batch, 32, 16, 16)
# Conv block 2
x = self.conv2(x) # (batch, 64, 16, 16)
x = F.relu(x)
x = self.pool(x) # (batch, 64, 8, 8)
# Flatten and FC layers
x = x.view(x.size(0), -1) # (batch, 64*8*8)
x = F.relu(self.fc1(x))
x = self.dropout(x)
x = self.fc2(x)
return x
model = CNN().to(device)
Recurrent Neural Networks
class RNN(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, output_size):
super(RNN, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.lstm = nn.LSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=True, # Input shape: (batch, seq_len, input_size)
dropout=0.5
)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
# x shape: (batch, seq_len, input_size)
lstm_out, (h_n, c_n) = self.lstm(x)
# lstm_out: (batch, seq_len, hidden_size)
# h_n: (num_layers, batch, hidden_size) - final hidden state
# Use last hidden state for classification
last_hidden = h_n[-1] # (batch, hidden_size)
out = self.fc(last_hidden) # (batch, output_size)
return out
model = RNN(input_size=100, hidden_size=256, num_layers=2, output_size=10).to(device)
Model Evaluation
# Evaluation mode (disables dropout, batch norm uses running stats)
model.eval()
correct = 0
total = 0
with torch.no_grad(): # Disable gradient computation
for batch_x, batch_y in test_dataloader:
batch_x, batch_y = batch_x.to(device), batch_y.to(device)
logits = model(batch_x)
predictions = torch.argmax(logits, dim=1)
correct += (predictions == batch_y).sum().item()
total += batch_y.size(0)
accuracy = correct / total
print(f"Accuracy: {accuracy:.4f}")
# Switch back to training mode
model.train()
Saving and Loading Models
# Save model
torch.save(model.state_dict(), 'model.pth')
# Load model
model = SimpleNet().to(device)
model.load_state_dict(torch.load('model.pth'))
# Save entire checkpoint
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}
torch.save(checkpoint, 'checkpoint.pth')
# Load checkpoint
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
Common Optimizers
import torch.optim as optim
# SGD with momentum
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# Adam (adaptive learning rate)
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
# RMSprop
optimizer = optim.RMSprop(model.parameters(), lr=0.01, alpha=0.99)
# Learning rate scheduling
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
# In training loop:
for epoch in range(num_epochs):
# ... training code ...
scheduler.step() # Decay learning rate
Loss Functions
# Classification
criterion = nn.CrossEntropyLoss() # Combines LogSoftmax + NLLLoss
criterion = nn.BCEWithLogitsLoss() # Binary classification
# Regression
criterion = nn.MSELoss() # Mean Squared Error
criterion = nn.L1Loss() # Mean Absolute Error
criterion = nn.SmoothL1Loss() # Huber loss
# Custom loss
class CustomLoss(nn.Module):
def forward(self, pred, target):
return (pred - target).pow(2).mean()
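One shape gotcha worth making explicit: nn.CrossEntropyLoss takes raw logits of shape (batch, num_classes) and integer class indices of shape (batch,); do not apply softmax first or one-hot the targets. A small sketch:

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])  # raw scores, shape (batch=2, classes=3)
targets = torch.tensor([0, 2])            # class indices, shape (batch,)
loss = criterion(logits, targets)
print(loss.item())  # a positive scalar
```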
Advanced Techniques
Batch Normalization
class BNNetwork(nn.Module):
def __init__(self):
super(BNNetwork, self).__init__()
self.fc1 = nn.Linear(784, 256)
self.bn1 = nn.BatchNorm1d(256) # Normalize features
self.fc2 = nn.Linear(256, 128)
self.bn2 = nn.BatchNorm1d(128)
self.fc3 = nn.Linear(128, 10)
def forward(self, x):
x = x.view(x.size(0), -1)
x = self.fc1(x)
x = self.bn1(x) # Normalize after linear layer
x = F.relu(x)
x = self.fc2(x)
x = self.bn2(x)
x = F.relu(x)
x = self.fc3(x)
return x
Gradient Clipping
# Prevent exploding gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.1)
Mixed Precision Training
from torch.amp import autocast, GradScaler  # torch.cuda.amp is deprecated in recent PyTorch

scaler = GradScaler("cuda")
for batch_x, batch_y in dataloader:
    optimizer.zero_grad()
    with autocast("cuda"):  # Automatically cast to float16 where safe
        logits = model(batch_x)
        loss = criterion(logits, batch_y)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
Time Complexity
| Operation | Time Complexity |
|---|---|
| Forward pass | O(n * hidden_size) for dense layers |
| Backward pass | O(n * hidden_size) (2-3x forward) |
| Conv2D | O(H * W * C_in * C_out * K^2) per sample |
| LSTM | O(seq_len * hidden_size^2) per sample |
Best Practices
- Use DataLoader for batching and shuffling
- Track metrics with tensorboard or wandb
- Use gradient clipping for unstable training
- Normalize inputs (mean=0, std=1)
- Monitor learning - plot loss and metrics
- Save checkpoints periodically during training
- Use model.eval() during validation/testing
- Pin memory for faster data loading:
DataLoader(..., pin_memory=True)
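For the input-normalization practice above, compute statistics on the training split only and reuse them for validation and test; a sketch with synthetic data:

```python
import torch

train_x = torch.randn(100, 5) * 3 + 7       # synthetic training features
test_x = torch.randn(20, 5) * 3 + 7

mean = train_x.mean(dim=0)                  # per-feature stats from train only
std = train_x.std(dim=0)

train_norm = (train_x - mean) / (std + 1e-8)
test_norm = (test_x - mean) / (std + 1e-8)  # reuse the train statistics

print(train_norm.mean(dim=0).abs().max())   # ~0: features are centered
```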
Common Issues
Out of Memory
# Solution 1: Reduce batch size
batch_size = 16 # Instead of 32
# Solution 2: Gradient accumulation
accumulation_steps = 4
for i, (batch_x, batch_y) in enumerate(dataloader):
logits = model(batch_x)
loss = criterion(logits, batch_y) / accumulation_steps
loss.backward()
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
NaN Loss
- Learning rate too high
- Batch normalization issues
- Unstable loss function
- Missing gradient clipping
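To locate where a NaN first appears, check tensors directly and, while debugging, enable autograd's anomaly mode (it is slow, so turn it off afterwards):

```python
import torch

x = torch.tensor([1.0, float("nan"), 3.0])

# Quick checks usable on losses, activations, or gradients
print(torch.isnan(x).any())     # tensor(True)
print(torch.isfinite(x).all())  # tensor(False)

# Make autograd raise at the exact op that produced a NaN gradient
torch.autograd.set_detect_anomaly(True)
# ... run the suspicious forward/backward pass here ...
torch.autograd.set_detect_anomaly(False)
```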
Slow Training
- Use GPU (move model and data to CUDA)
- Increase batch size
- Use mixed precision training
- Profile with torch.profiler
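A minimal torch.profiler sketch for finding the slow operations (CPU-only here so it runs anywhere):

```python
import torch
from torch.profiler import profile, ProfilerActivity

model = torch.nn.Linear(128, 64)
inputs = torch.randn(32, 128)

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    for _ in range(5):
        model(inputs)

# Table of the most expensive operators
report = prof.key_averages().table(sort_by="cpu_time_total", row_limit=5)
print(report)
```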
ELI10
PyTorch is like a smart building assistant:
- You design the blueprint (define the network architecture)
- PyTorch remembers every step (autograd tracks all operations)
- You show examples (training data)
- PyTorch automatically learns (backpropagation adjusts weights)
- It gets better each time (more epochs = better performance)
It's like learning to cook - you follow the recipe, taste the result, adjust ingredients, and get better over time!
Further Resources
- PyTorch Official Documentation
- PyTorch Tutorials
- Deep Learning Specialization with PyTorch
- PyTorch Lightning - High-level wrapper
- Hugging Face Transformers - NLP with PyTorch
- Fast.ai - Practical deep learning course
NumPy for Machine Learning
NumPy is the foundational numerical computing library for Python and forms the backbone of the ML/AI ecosystem. Understanding NumPy deeply is essential for efficient machine learning implementations.
Table of Contents
- Why NumPy for ML
- Array Creation Patterns
- Indexing and Slicing
- Broadcasting
- Vectorization
- Reshaping and Transformations
- Matrix Operations
- Statistical Operations
- Linear Algebra
- Random Number Generation
- Advanced Patterns
- Performance Optimization
- Common ML Patterns
Why NumPy for ML
Speed: NumPy operations are implemented in C and are vectorized, making them 10-100x faster than pure Python loops.
Memory Efficiency: Contiguous memory layout and fixed data types reduce overhead.
Foundation: PyTorch, TensorFlow, and scikit-learn all build on NumPy conventions.
Broadcasting: Implicit expansion of arrays enables concise, efficient code.
import numpy as np
# Pure Python (slow)
result = []
for i in range(1000000):
result.append(i ** 2)
# NumPy (fast)
result = np.arange(1000000) ** 2
Array Creation Patterns
Basic Creation
# From lists
arr = np.array([1, 2, 3, 4, 5])
matrix = np.array([[1, 2, 3], [4, 5, 6]])
# Specify dtype for memory efficiency
arr_int8 = np.array([1, 2, 3], dtype=np.int8) # 1 byte per element
arr_float32 = np.array([1, 2, 3], dtype=np.float32) # 4 bytes per element
arr_float64 = np.array([1, 2, 3], dtype=np.float64) # 8 bytes per element (default)
Initialization Patterns for ML
# Zeros - common for initializing gradients or counts
zeros = np.zeros((3, 4))
zeros_like = np.zeros_like(existing_array)
# Ones - useful for bias initialization
ones = np.ones((3, 4))
ones_like = np.ones_like(existing_array)
# Empty - fastest, doesn't initialize (use when you'll overwrite)
empty = np.empty((3, 4))
# Full - initialize with specific value
full = np.full((3, 4), 0.01) # Initialize all to 0.01
# Identity matrix - common in linear algebra
identity = np.eye(5)
identity_offset = np.eye(5, k=1) # Offset diagonal
# Ranges
arange = np.arange(0, 10, 2) # [0, 2, 4, 6, 8]
linspace = np.linspace(0, 1, 5) # [0.0, 0.25, 0.5, 0.75, 1.0]
logspace = np.logspace(0, 2, 5) # [1, 3.16, 10, 31.62, 100] logarithmically spaced
# Meshgrid - useful for coordinate generation
x = np.linspace(-5, 5, 100)
y = np.linspace(-5, 5, 100)
X, Y = np.meshgrid(x, y) # Create 2D coordinate grids
Random Initialization (Modern API)
# Modern way (NumPy 1.17+)
rng = np.random.default_rng(seed=42)
# Uniform distribution [0, 1)
uniform = rng.random((3, 4))
# Normal/Gaussian distribution
normal = rng.normal(loc=0, scale=1, size=(3, 4))
# Xavier/Glorot initialization for neural networks
n_in, n_out = 784, 256
xavier = rng.normal(0, np.sqrt(2 / (n_in + n_out)), (n_in, n_out))
# He initialization (for ReLU networks)
he = rng.normal(0, np.sqrt(2 / n_in), (n_in, n_out))
# Integer random values
randint = rng.integers(0, 10, size=(3, 4))
# Choice (sampling)
choices = rng.choice([1, 2, 3, 4, 5], size=10, replace=True)
Indexing and Slicing
Basic Indexing
arr = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
# Single element
arr[0] # 0
arr[-1] # 9
# Slicing: [start:stop:step]
arr[2:5] # [2, 3, 4]
arr[::2] # [0, 2, 4, 6, 8] - every second element
arr[::-1] # [9, 8, 7, 6, 5, 4, 3, 2, 1, 0] - reverse
arr[5:] # [5, 6, 7, 8, 9]
arr[:5] # [0, 1, 2, 3, 4]
Multidimensional Indexing
matrix = np.array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
# Element access
matrix[0, 0] # 1
matrix[1, 2] # 6
# Row and column slicing
matrix[0, :] # [1, 2, 3] - first row
matrix[:, 0] # [1, 4, 7] - first column
matrix[:2, :2] # [[1, 2], [4, 5]] - top-left 2x2
# Stride tricks
matrix[::2, ::2] # Every other row and column
Boolean Indexing (Critical for ML)
arr = np.array([1, -2, 3, -4, 5, -6])
# Boolean mask
mask = arr > 0
positive = arr[mask] # [1, 3, 5]
# Inline
positive = arr[arr > 0]
even = arr[arr % 2 == 0]
# Compound conditions
arr[(arr > 0) & (arr < 4)] # [1, 3]
arr[(arr < 0) | (arr > 4)] # [-2, -4, 5, -6]
# Filtering outliers
data = np.random.randn(1000)
mean, std = data.mean(), data.std()
filtered = data[np.abs(data - mean) < 2 * std] # Remove outliers beyond 2 sigma
# Setting values with boolean indexing
arr[arr < 0] = 0 # Clip negative values to 0 (ReLU activation!)
Fancy Indexing
arr = np.array([10, 20, 30, 40, 50])
# Index with array of integers
indices = np.array([0, 2, 4])
arr[indices] # [10, 30, 50]
# Multidimensional fancy indexing
matrix = np.arange(12).reshape(3, 4)
rows = np.array([0, 2, 2])
cols = np.array([1, 3, 0])
matrix[rows, cols] # Elements at (0,1), (2,3), (2,0)
# Batch indexing (common in ML)
batch = np.random.randn(32, 10) # 32 samples, 10 classes
labels = np.array([3, 1, 5, ...]) # True class for each sample
selected_logits = batch[np.arange(32), labels] # Logits for true classes
Advanced Slicing Tricks
# Ellipsis (...) - all remaining dimensions
tensor = np.random.randn(2, 3, 4, 5)
tensor[0, ...] # Same as tensor[0, :, :, :]
tensor[..., 0] # Same as tensor[:, :, :, 0]
# np.newaxis or None - add dimension
arr = np.array([1, 2, 3])
arr[:, np.newaxis] # Shape (3, 1) - column vector
arr[np.newaxis, :] # Shape (1, 3) - row vector
Broadcasting
Broadcasting allows NumPy to perform operations on arrays of different shapes efficiently without copying data.
Broadcasting Rules
- If arrays have different dimensions, pad the smaller shape with ones on the left
- Arrays are compatible if, for each dimension, the sizes are equal or one of them is 1
- After broadcasting, each dimension becomes the maximum of the two
# Rule visualization
A:      (3, 4, 5)
B:         (1, 5)
Result: (3, 4, 5)

A:      (3, 1, 5)
B:      (3, 4, 1)
Result: (3, 4, 5)
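The rules above can be checked programmatically: np.broadcast_shapes (NumPy 1.20+) applies exactly this pad-then-match procedure and raises ValueError for incompatible shapes:

```python
import numpy as np

print(np.broadcast_shapes((3, 4, 5), (1, 5)))     # (3, 4, 5)
print(np.broadcast_shapes((3, 1, 5), (3, 4, 1)))  # (3, 4, 5)

try:
    np.broadcast_shapes((3, 1), (4, 1))           # sizes 3 and 4 clash
except ValueError as err:
    print("incompatible:", err)
```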
Common Broadcasting Patterns
# Scalar with array
arr = np.array([1, 2, 3, 4])
arr * 2 # [2, 4, 6, 8]
# 1D with 2D (very common in ML)
matrix = np.array([[1, 2, 3],
[4, 5, 6]])
row_vector = np.array([10, 20, 30])
matrix + row_vector
# [[11, 22, 33],
# [14, 25, 36]]
# Broadcasting for normalization
data = np.random.randn(100, 5) # 100 samples, 5 features
mean = data.mean(axis=0) # Shape (5,)
std = data.std(axis=0) # Shape (5,)
normalized = (data - mean) / std # Broadcasting happens automatically
# Column vector broadcasting
col_vector = np.array([[1], [2], [3]]) # Shape (3, 1)
row_vector = np.array([10, 20, 30]) # Shape (3,)
result = col_vector + row_vector
# [[11, 21, 31],
# [12, 22, 32],
# [13, 23, 33]]
Practical ML Examples
# Batch normalization
batch = np.random.randn(32, 64, 64, 3) # 32 images, 64x64, 3 channels
mean = batch.mean(axis=(0, 1, 2), keepdims=True) # Shape (1, 1, 1, 3)
std = batch.std(axis=(0, 1, 2), keepdims=True)
normalized_batch = (batch - mean) / (std + 1e-8)
# Distance matrix computation
X = np.random.randn(100, 50) # 100 samples, 50 features
# Pairwise squared distances using broadcasting
X_expanded = X[:, np.newaxis, :] # Shape (100, 1, 50)
X2_expanded = X[np.newaxis, :, :] # Shape (1, 100, 50)
distances = np.sum((X_expanded - X2_expanded) ** 2, axis=2) # (100, 100)
# Attention mechanism (simplified)
Q = np.random.randn(10, 64) # 10 queries, 64 dims
K = np.random.randn(20, 64) # 20 keys, 64 dims
# Compute attention scores
scores = Q @ K.T # (10, 20)
Broadcasting Pitfalls
# Unintended broadcasting
a = np.random.randn(3, 1)
b = np.random.randn(4, 1)
# a + b raises error - shapes (3,1) and (4,1) incompatible
# Accidental dimension loss
a = np.random.randn(5, 1)
b = a.flatten() # Shape (5,) not (5, 1)
# Now b broadcasts differently!
# Always check shapes (a has shape (5, 1) here)
assert a.shape == (5, 1), f"Shape mismatch: {a.shape}"
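A minimal sketch of the first pitfall above: shapes (3, 1) and (4, 1) fail the compatibility rule in dimension 0 (3 vs 4, neither is 1), and NumPy raises a `ValueError` rather than broadcasting:

```python
import numpy as np

a = np.random.randn(3, 1)
b = np.random.randn(4, 1)
try:
    a + b  # dimension 0 is 3 vs 4 - incompatible
except ValueError as e:
    print("broadcast error:", e)
```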
Vectorization
Vectorization is the process of replacing explicit loops with array operations. It's fundamental to writing efficient NumPy code.
Why Vectorization Matters
import time
# Non-vectorized
data = list(range(1000000))
start = time.time()
result = [x ** 2 for x in data]
print(f"Loop: {time.time() - start:.4f}s")
# Vectorized
data = np.arange(1000000)
start = time.time()
result = data ** 2
print(f"Vectorized: {time.time() - start:.4f}s")
# Typically 50-100x faster!
Basic Vectorization Patterns
# Element-wise operations (automatically vectorized)
a = np.array([1, 2, 3, 4])
b = np.array([10, 20, 30, 40])
a + b # [11, 22, 33, 44]
a * b # [10, 40, 90, 160]
a ** b # [1, 1048576, ...]
np.sin(a) # [0.841, 0.909, 0.141, -0.757]
np.exp(a) # [2.718, 7.389, 20.085, 54.598]
# Comparison operators
a > 2 # [False, False, True, True]
np.maximum(a, 2) # [2, 2, 3, 4] - element-wise max
Replacing Loops with Vectorization
# Example 1: Sigmoid activation
def sigmoid_loop(x):
    result = np.zeros_like(x)
    for i in range(len(x)):
        result[i] = 1 / (1 + np.exp(-x[i]))
    return result

def sigmoid_vectorized(x):
    return 1 / (1 + np.exp(-x))
# Example 2: Pairwise distances
def distances_loop(X, Y):
    n, m = len(X), len(Y)
    distances = np.zeros((n, m))
    for i in range(n):
        for j in range(m):
            distances[i, j] = np.sqrt(np.sum((X[i] - Y[j]) ** 2))
    return distances

def distances_vectorized(X, Y):
    # Using broadcasting
    return np.sqrt(np.sum((X[:, np.newaxis] - Y[np.newaxis, :]) ** 2, axis=2))
# Example 3: Moving average
def moving_average_loop(arr, window):
    result = np.zeros(len(arr) - window + 1)
    for i in range(len(result)):
        result[i] = arr[i:i+window].mean()
    return result

def moving_average_vectorized(arr, window):
    # Using convolution
    return np.convolve(arr, np.ones(window)/window, mode='valid')
Advanced Vectorization
# Conditional operations - use np.where instead of if/else
x = np.random.randn(100)
# Bad
result = np.zeros_like(x)
for i in range(len(x)):
    result[i] = x[i] if x[i] > 0 else 0
# Good - vectorized ReLU
result = np.where(x > 0, x, 0)
# Even better
result = np.maximum(x, 0)
# Multiple conditions - use np.select
x = np.arange(-5, 6)
conditions = [x < -2, (x >= -2) & (x <= 2), x > 2]
choices = [-1, 0, 1]
result = np.select(conditions, choices)
# Vectorized gradient clipping
gradients = np.random.randn(1000, 100)
clip_value = 1.0
norm = np.linalg.norm(gradients, axis=1, keepdims=True)
gradients = np.where(norm > clip_value,
gradients * clip_value / norm,
gradients)
Reshaping and Transformations
Basic Reshaping
arr = np.arange(12) # [0, 1, 2, ..., 11]
# Reshape - returns view if possible
arr.reshape(3, 4)
# [[ 0 1 2 3]
# [ 4 5 6 7]
# [ 8 9 10 11]]
arr.reshape(2, 6)
arr.reshape(2, 2, 3) # 3D array
# Infer dimension with -1
arr.reshape(3, -1) # NumPy calculates: (3, 4)
arr.reshape(-1, 2) # (6, 2)
arr.reshape(-1) # Flatten to 1D
# Column-major reshape (fills columns first)
arr.reshape(3, 4, order='F')  # Fortran-style (column-major)
Flatten vs Ravel vs Reshape
matrix = np.array([[1, 2], [3, 4]])
# flatten() - always returns a copy
flat1 = matrix.flatten()
flat1[0] = 999
# matrix unchanged
# ravel() - returns view if possible (more efficient)
flat2 = matrix.ravel()
flat2[0] = 999
# matrix[0, 0] is now 999!
# reshape(-1) - same as ravel
flat3 = matrix.reshape(-1)
Transposition and Axis Manipulation
# 2D transpose
matrix = np.array([[1, 2, 3], [4, 5, 6]])
matrix.T
# [[1, 4],
# [2, 5],
# [3, 6]]
# Multi-dimensional transpose
tensor = np.random.randn(2, 3, 4)
tensor.transpose(2, 0, 1) # Move axes: (4, 2, 3)
tensor.transpose() # Reverse all axes: (4, 3, 2)
# Swapaxes - swap two specific axes
tensor.swapaxes(0, 2) # Swap first and last: (4, 3, 2)
# moveaxis - more intuitive for single axis moves
tensor_moved = np.moveaxis(tensor, 0, -1) # Move first axis to last: (3, 4, 2)
Dimension Manipulation
arr = np.array([1, 2, 3])
# Add dimensions
arr_col = arr[:, np.newaxis] # Shape (3, 1)
arr_row = arr[np.newaxis, :] # Shape (1, 3)
arr_3d = arr[:, np.newaxis, np.newaxis] # Shape (3, 1, 1)
# Using expand_dims
arr_col = np.expand_dims(arr, axis=1) # Shape (3, 1)
arr_3d = np.expand_dims(arr, axis=(1, 2)) # Shape (3, 1, 1)
# Remove dimensions - squeeze
arr_squeezed = np.squeeze(arr_col) # Back to (3,)
# Broadcast to specific shape
arr_broadcast = np.broadcast_to(arr[:, np.newaxis], (3, 5))
# [[1, 1, 1, 1, 1],
# [2, 2, 2, 2, 2],
# [3, 3, 3, 3, 3]]
Stacking and Splitting
a = np.array([1, 2, 3])
b = np.array([4, 5, 6])
# Vertical stack (along axis 0)
np.vstack([a, b])
# [[1, 2, 3],
# [4, 5, 6]]
# Horizontal stack (along axis 1)
np.hstack([a, b])
# [1, 2, 3, 4, 5, 6]
# Stack along new axis
np.stack([a, b], axis=0) # Shape (2, 3)
np.stack([a, b], axis=1) # Shape (3, 2)
# Concatenate - general stacking
np.concatenate([a, b], axis=0)
# Splitting
arr = np.arange(9)
np.split(arr, 3) # [array([0, 1, 2]), array([3, 4, 5]), array([6, 7, 8])]
np.array_split(arr, 4) # Unequal splits allowed
# 2D splitting
matrix = np.random.randn(6, 4)
np.vsplit(matrix, 3) # Split into 3 horizontal slices
np.hsplit(matrix, 2) # Split into 2 vertical slices
Practical ML Reshaping Examples
# Batch flattening for fully connected layer
batch_images = np.random.randn(32, 28, 28, 1) # 32 MNIST images
flattened = batch_images.reshape(32, -1) # (32, 784)
# Channel manipulation (NHWC to NCHW)
nhwc = np.random.randn(10, 224, 224, 3)
nchw = nhwc.transpose(0, 3, 1, 2) # (10, 3, 224, 224)
# Reshape for sequence processing
time_series = np.random.randn(1000, 10) # 1000 timesteps, 10 features
batched = time_series.reshape(-1, 50, 10) # 20 sequences of 50 timesteps
# Tile for data augmentation
pattern = np.array([1, 2, 3])
tiled = np.tile(pattern, 5) # [1, 2, 3, 1, 2, 3, ...]
tiled_2d = np.tile(pattern, (3, 1)) # Repeat as rows
Matrix Operations
Element-wise vs Matrix Operations
A = np.array([[1, 2], [3, 4]])
B = np.array([[5, 6], [7, 8]])
# Element-wise multiplication (Hadamard product)
A * B
# [[ 5, 12],
# [21, 32]]
# Matrix multiplication
A @ B # Python 3.5+ operator
np.dot(A, B)
np.matmul(A, B)
# [[19, 22],
# [43, 50]]
Different Multiplication Operations
# 1D arrays - dot product
a = np.array([1, 2, 3])
b = np.array([4, 5, 6])
np.dot(a, b) # 1*4 + 2*5 + 3*6 = 32
# 2D matrix multiplication
A = np.random.randn(3, 4)
B = np.random.randn(4, 5)
C = A @ B # Shape (3, 5)
# Batch matrix multiplication
batch_A = np.random.randn(10, 3, 4)
batch_B = np.random.randn(10, 4, 5)
batch_C = batch_A @ batch_B # Shape (10, 3, 5)
# Outer product
a = np.array([1, 2, 3])
b = np.array([4, 5])
np.outer(a, b)
# [[ 4, 5],
# [ 8, 10],
# [12, 15]]
# Inner product (same as dot for 1D; operands must have equal length)
np.inner(a, np.array([4, 5, 6]))  # 1*4 + 2*5 + 3*6 = 32
# Kronecker product
np.kron(A, B) # Tensor product
Matrix Properties
A = np.array([[1, 2], [3, 4]])
# Trace (sum of diagonal)
np.trace(A) # 1 + 4 = 5
# Determinant
np.linalg.det(A) # -2.0
# Rank
np.linalg.matrix_rank(A) # 2
# Norm
np.linalg.norm(A) # Frobenius norm (default)
np.linalg.norm(A, 'fro') # Frobenius norm
np.linalg.norm(A, 2) # Spectral norm
np.linalg.norm(A, 'nuc') # Nuclear norm
# Condition number
np.linalg.cond(A) # Ratio of largest to smallest singular value
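Since the condition number is the ratio of the largest to the smallest singular value, it can be cross-checked against the SVD directly; a quick sanity check:

```python
import numpy as np

A = np.array([[1.0, 2.0], [3.0, 4.0]])
s = np.linalg.svd(A, compute_uv=False)  # singular values, descending
assert np.isclose(np.linalg.cond(A), s[0] / s[-1])
```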
Advanced Matrix Operations
# Matrix power
A = np.array([[1, 2], [3, 4]])
np.linalg.matrix_power(A, 3) # A @ A @ A
# Matrix exponential (important in physics, ODEs)
from scipy.linalg import expm
expm(A)
# Batch operations
batch = np.random.randn(100, 10, 10)
# Batch determinant
dets = np.linalg.det(batch) # Shape (100,)
# Einsum for complex operations (see Advanced Patterns)
# Batch matrix trace
traces = np.einsum('bii->b', batch)
Statistical Operations
Basic Statistics
data = np.random.randn(100, 5) # 100 samples, 5 features
# Central tendency
np.mean(data) # Overall mean
np.median(data) # Median
np.percentile(data, 50) # Same as median
np.quantile(data, 0.5) # Same as median
# Spread
np.std(data) # Standard deviation
np.var(data) # Variance
np.ptp(data) # Peak to peak (max - min)
# Extremes
np.min(data)
np.max(data)
np.argmin(data) # Index of minimum
np.argmax(data) # Index of maximum
Axis-wise Operations
data = np.random.randn(100, 5)
# Along columns (across samples)
feature_means = np.mean(data, axis=0) # Shape (5,)
feature_stds = np.std(data, axis=0)
# Along rows (across features)
sample_means = np.mean(data, axis=1) # Shape (100,)
# Keep dimensions for broadcasting
feature_means = np.mean(data, axis=0, keepdims=True) # Shape (1, 5)
normalized = (data - feature_means) / np.std(data, axis=0, keepdims=True)
# Multiple axes
tensor = np.random.randn(10, 20, 30, 40)
mean_spatial = np.mean(tensor, axis=(1, 2)) # Average over dimensions 1 and 2
Normalization Techniques
# Z-score normalization (standardization)
def standardize(X, axis=0):
    mean = np.mean(X, axis=axis, keepdims=True)
    std = np.std(X, axis=axis, keepdims=True)
    return (X - mean) / (std + 1e-8)

# Min-max normalization
def min_max_normalize(X, axis=0):
    min_val = np.min(X, axis=axis, keepdims=True)
    max_val = np.max(X, axis=axis, keepdims=True)
    return (X - min_val) / (max_val - min_val + 1e-8)

# L2 normalization (unit vectors)
def l2_normalize(X, axis=1):
    norm = np.linalg.norm(X, axis=axis, keepdims=True)
    return X / (norm + 1e-8)

# Batch normalization (simplified)
def batch_norm(X, gamma=1, beta=0, epsilon=1e-8):
    mean = np.mean(X, axis=0, keepdims=True)
    var = np.var(X, axis=0, keepdims=True)
    X_norm = (X - mean) / np.sqrt(var + epsilon)
    return gamma * X_norm + beta

# Whitening (decorrelation)
def whiten(X):
    X_centered = X - np.mean(X, axis=0)
    cov = np.cov(X_centered, rowvar=False)
    U, S, Vt = np.linalg.svd(cov)
    W = U @ np.diag(1.0 / np.sqrt(S + 1e-8)) @ U.T
    return X_centered @ W
Statistical Functions
# Cumulative operations
arr = np.array([1, 2, 3, 4, 5])
np.cumsum(arr) # [ 1, 3, 6, 10, 15]
np.cumprod(arr) # [ 1, 2, 6, 24, 120]
# Correlation and covariance
data = np.random.randn(100, 5)
np.corrcoef(data, rowvar=False) # Correlation matrix (5, 5)
np.cov(data, rowvar=False) # Covariance matrix (5, 5)
# Histogram
values, bins = np.histogram(data, bins=10)
values, bins = np.histogram(data, bins='auto') # Automatic binning
# Percentiles and quantiles
np.percentile(data, [25, 50, 75]) # Quartiles
np.quantile(data, [0.25, 0.5, 0.75])
# Binning
digitized = np.digitize(data, bins=[-1, 0, 1]) # Classify into bins
# Weighted statistics
weights = np.random.rand(100)
np.average(data, weights=weights, axis=0)
Linear Algebra
Matrix Decompositions
# Eigenvalue decomposition
A = np.random.randn(5, 5)
A = A + A.T # Make symmetric
eigenvalues, eigenvectors = np.linalg.eig(A)
# For symmetric matrices (faster and more stable)
eigenvalues, eigenvectors = np.linalg.eigh(A)
# Singular Value Decomposition (SVD)
M = np.random.randn(10, 5)
U, S, Vt = np.linalg.svd(M, full_matrices=False)
# M ≈ U @ np.diag(S) @ Vt
# U: (10, 5), S: (5,), Vt: (5, 5)
# QR decomposition
Q, R = np.linalg.qr(M)
# M = Q @ R, Q is orthogonal, R is upper triangular
# Cholesky decomposition (for positive definite matrices)
A = np.random.randn(5, 5)
A = A.T @ A # Make positive definite
L = np.linalg.cholesky(A)
# A = L @ L.T, L is lower triangular
# LU decomposition (requires scipy)
from scipy.linalg import lu
P, L, U = lu(A)
Solving Linear Systems
# Solve Ax = b
A = np.array([[3, 1], [1, 2]])
b = np.array([9, 8])
x = np.linalg.solve(A, b) # x = [2, 3]
# Least squares solution (when system is overdetermined)
# Solve ||Ax - b||^2
A = np.random.randn(100, 5)
b = np.random.randn(100)
x, residuals, rank, s = np.linalg.lstsq(A, b, rcond=None)
# Matrix inverse (square, non-singular matrices only)
A_sq = np.random.randn(5, 5)
A_inv = np.linalg.inv(A_sq)
# But prefer solving: np.linalg.solve(A, b)
# rather than: np.linalg.inv(A) @ b
# Pseudo-inverse (Moore-Penrose)
A = np.random.randn(10, 5)
A_pinv = np.linalg.pinv(A)
Matrix Factorizations for ML
# PCA using SVD
def pca(X, n_components):
    # Center the data
    X_centered = X - np.mean(X, axis=0)
    # SVD
    U, S, Vt = np.linalg.svd(X_centered, full_matrices=False)
    # Project onto top components
    components = Vt[:n_components]
    X_pca = X_centered @ components.T
    # Explained variance
    explained_variance = (S ** 2) / (len(X) - 1)
    explained_variance_ratio = explained_variance[:n_components] / explained_variance.sum()
    return X_pca, components, explained_variance_ratio

# Low-rank approximation
M = np.random.randn(100, 50)
U, S, Vt = np.linalg.svd(M, full_matrices=False)
k = 10  # Keep top 10 components
M_approx = U[:, :k] @ np.diag(S[:k]) @ Vt[:k, :]

# Power iteration for top eigenvector
def power_iteration(A, num_iterations=100):
    v = np.random.randn(A.shape[1])
    for _ in range(num_iterations):
        v = A @ v
        v = v / np.linalg.norm(v)
    eigenvalue = v @ A @ v
    return eigenvalue, v
Random Number Generation
Modern Random API
# Create generator with seed
rng = np.random.default_rng(42)
# Uniform distributions
rng.random((3, 4)) # [0, 1)
rng.uniform(0, 10, size=(3, 4)) # [0, 10)
rng.integers(0, 100, size=10) # [0, 100)
# Normal/Gaussian
rng.normal(loc=0, scale=1, size=(3, 4))
rng.standard_normal((3, 4)) # mean=0, std=1
# Other distributions
rng.exponential(scale=1.0, size=100)
rng.poisson(lam=5, size=100)
rng.binomial(n=10, p=0.5, size=100)
rng.beta(a=2, b=5, size=100)
rng.gamma(shape=2, scale=1, size=100)
rng.multinomial(n=10, pvals=[0.2, 0.3, 0.5], size=20)
Sampling and Shuffling
rng = np.random.default_rng(42)
# Random choice
data = np.arange(100)
sample = rng.choice(data, size=10, replace=False) # Without replacement
# Weighted sampling
weights = np.array([0.1, 0.2, 0.3, 0.4])
samples = rng.choice(4, size=1000, p=weights)
# Shuffle
arr = np.arange(10)
rng.shuffle(arr) # In-place shuffle
# Permutation (returns shuffled copy)
perm = rng.permutation(arr)
perm_indices = rng.permutation(len(arr))
# Random partitioning for train/test split
indices = rng.permutation(len(data))
train_size = int(0.8 * len(data))
train_indices = indices[:train_size]
test_indices = indices[train_size:]
Reproducibility
# Global seed (legacy, not recommended)
np.random.seed(42)
# Better: use Generator instances
rng1 = np.random.default_rng(42)
rng2 = np.random.default_rng(42)
# rng1 and rng2 produce identical sequences
# Independent streams
from numpy.random import SeedSequence, Generator, PCG64
ss = SeedSequence(12345)
child_seeds = ss.spawn(10) # Create 10 independent streams
streams = [Generator(PCG64(s)) for s in child_seeds]
# Each stream is independent
samples = [stream.random(100) for stream in streams]
# Save and restore state
state = rng.bit_generator.state
# ... later ...
rng.bit_generator.state = state # Restore exact state
Initialization Strategies for Neural Networks
rng = np.random.default_rng(42)
def init_weights(shape, method='xavier', rng=None):
    if rng is None:
        rng = np.random.default_rng()
    n_in, n_out = shape
    if method == 'xavier' or method == 'glorot':
        # Xavier/Glorot initialization (for tanh, sigmoid)
        limit = np.sqrt(6 / (n_in + n_out))
        return rng.uniform(-limit, limit, shape)
    elif method == 'he':
        # He initialization (for ReLU)
        std = np.sqrt(2 / n_in)
        return rng.normal(0, std, shape)
    elif method == 'lecun':
        # LeCun initialization
        std = np.sqrt(1 / n_in)
        return rng.normal(0, std, shape)
    elif method == 'orthogonal':
        # Orthogonal initialization
        flat_shape = (n_in, n_out)
        a = rng.normal(0, 1, flat_shape)
        u, _, v = np.linalg.svd(a, full_matrices=False)
        q = u if u.shape == flat_shape else v
        return q
    else:
        return rng.normal(0, 0.01, shape)

# Dropout mask
def dropout_mask(shape, p=0.5, rng=None):
    if rng is None:
        rng = np.random.default_rng()
    mask = rng.random(shape) > p
    return mask / (1 - p)  # Inverted dropout

# Data augmentation noise
def add_gaussian_noise(X, std=0.1, rng=None):
    if rng is None:
        rng = np.random.default_rng()
    noise = rng.normal(0, std, X.shape)
    return X + noise
Advanced Patterns
Einstein Summation (einsum)
Einstein summation is a compact notation for array operations. It's extremely powerful once you understand it.
# Basics
a = np.arange(6).reshape(2, 3)
b = np.arange(12).reshape(3, 4)
# Matrix multiplication: C[i,k] = sum_j A[i,j] * B[j,k]
c = np.einsum('ij,jk->ik', a, b)
# Same as: a @ b
# Trace: sum_i A[i,i]
A = np.random.randn(5, 5)
trace = np.einsum('ii->', A)
# Same as: np.trace(A)
# Diagonal: D[i] = A[i,i]
diag = np.einsum('ii->i', A)
# Same as: np.diag(A)
# Transpose: B[j,i] = A[i,j]
b = np.einsum('ij->ji', a)
# Same as: a.T
# Batch matrix multiplication
batch_a = np.random.randn(10, 3, 4)
batch_b = np.random.randn(10, 4, 5)
batch_c = np.einsum('bij,bjk->bik', batch_a, batch_b)
# Same as: batch_a @ batch_b
# Dot product: sum_i a[i] * b[i]
a = np.array([1, 2, 3])
b = np.array([4, 5, 6])
dot = np.einsum('i,i->', a, b)
# Same as: np.dot(a, b)
# Outer product: C[i,j] = a[i] * b[j]
outer = np.einsum('i,j->ij', a, b)
# Same as: np.outer(a, b)
# Element-wise multiplication and sum: sum_ij A[i,j] * B[i,j]
A = np.random.randn(3, 4)
B = np.random.randn(3, 4)
result = np.einsum('ij,ij->', A, B)
# Same as: np.sum(A * B)
Complex einsum Examples for ML
# Attention mechanism
Q = np.random.randn(10, 8, 64) # batch, query_len, dim
K = np.random.randn(10, 12, 64) # batch, key_len, dim
V = np.random.randn(10, 12, 64) # batch, value_len, dim
# Compute attention scores: scores[b,i,j] = sum_d Q[b,i,d] * K[b,j,d]
scores = np.einsum('bid,bjd->bij', Q, K) / np.sqrt(64)
# Apply attention to values: output[b,i,d] = sum_j scores[b,i,j] * V[b,j,d]
attention_weights = softmax(scores, axis=-1)  # softmax as defined under Activation Functions
output = np.einsum('bij,bjd->bid', attention_weights, V)
# Bilinear operation: y[b] = sum_ij x1[b,i] * W[i,j] * x2[b,j]
x1 = np.random.randn(32, 10)
x2 = np.random.randn(32, 20)
W = np.random.randn(10, 20)
y = np.einsum('bi,ij,bj->b', x1, W, x2)
# Batch trace
batch = np.random.randn(100, 5, 5)
traces = np.einsum('bii->b', batch)
# Frobenius norm squared
frob_sq = np.einsum('ij,ij->', A, A)
Universal Functions (ufuncs)
# Create custom ufunc
def relu_scalar(x):
    return max(0, x)
relu = np.frompyfunc(relu_scalar, 1, 1) # 1 input, 1 output
# Note: This is for educational purposes; use np.maximum(x, 0) in practice
# Accumulate methods
arr = np.array([1, 2, 3, 4, 5])
np.add.accumulate(arr) # [1, 3, 6, 10, 15] - cumsum
np.multiply.accumulate(arr) # [1, 2, 6, 24, 120] - cumprod
# Reduce methods
np.add.reduce(arr) # 15 - sum
np.multiply.reduce(arr) # 120 - product
np.maximum.reduce(arr) # 5 - max
# Outer methods
np.add.outer(arr[:3], arr[:3])
# [[2, 3, 4],
# [3, 4, 5],
# [4, 5, 6]]
# At method (in-place operations at indices)
arr = np.array([1, 2, 3, 4, 5])
np.add.at(arr, [0, 2, 4], 10) # arr: [11, 2, 13, 4, 15]
Memory Views and Copies
# View vs copy
arr = np.arange(10)
view = arr[::2] # View - no data copied
copy = arr[::2].copy() # Explicit copy
view[0] = 999
# arr is modified!
copy[0] = 999
# arr is unchanged
# Check if it's a view
view.base is arr # True
copy.base is None # True
# Some operations return views
arr.reshape(2, 5) # View (if possible)
arr.T # View
arr[::2] # View
# Some operations return copies
arr.flatten() # Copy
arr + 1 # Copy
arr[[0, 2, 4]] # Copy (fancy indexing)
# Avoid copies with out parameter
arr = np.random.randn(1000)
result = np.empty_like(arr)
np.sin(arr, out=result) # Compute in-place, no extra memory
# Compound operations
arr = np.random.randn(1000)
arr += 1 # In-place, no copy
arr *= 2 # In-place, no copy
# vs
arr = arr + 1 # Creates new array
Advanced Indexing Patterns
# Multi-dimensional boolean indexing
data = np.random.randn(10, 5)
mask = data > 0
positive_values = data[mask] # 1D array of positive values
# Keep structure with np.where
data_clipped = np.where(data > 0, data, 0) # ReLU
# np.where with conditions
condition = data > 0.5
result = np.where(condition, data * 2, data / 2)
# np.select for multiple conditions
conditions = [
data < -1,
(data >= -1) & (data < 0),
(data >= 0) & (data < 1),
data >= 1
]
choices = [-1, 0, 0, 1]
result = np.select(conditions, choices, default=0)
# np.choose (limited to at most 32 choice arrays)
indices = np.array([0, 1, 2, 1, 0])
choices = np.array([[1, 2, 3, 4, 5],
[10, 20, 30, 40, 50],
[100, 200, 300, 400, 500]])
result = np.choose(indices, choices) # [1, 20, 300, 40, 5]
# Advanced batch indexing
batch = np.random.randn(32, 10)
indices = np.array([3, 1, 5, ...]) # 32 indices
selected = batch[np.arange(32), indices] # 32 values
# Meshgrid for pairwise operations
x = np.array([1, 2, 3])
y = np.array([10, 20])
X, Y = np.meshgrid(x, y, indexing='ij')
# X: [[1, 1], Y: [[10, 20],
# [2, 2], [10, 20],
# [3, 3]] [10, 20]]
Performance Optimization
Memory Layout
# C-order (row-major) vs Fortran-order (column-major)
arr_c = np.array([[1, 2], [3, 4]], order='C') # Default
arr_f = np.array([[1, 2], [3, 4]], order='F')
# Check memory order
arr_c.flags['C_CONTIGUOUS'] # True
arr_f.flags['F_CONTIGUOUS'] # True
# Performance implication
# Iterating over rows is faster for C-order
# Iterating over columns is faster for F-order
# Use appropriate order for your access pattern
matrix_c = np.asarray(np.random.randn(1000, 1000), order='C')
matrix_f = np.asarray(np.random.randn(1000, 1000), order='F')
# Row-wise operations faster on C-order
row_sums_c = matrix_c.sum(axis=1)  # Fast
# Column-wise operations faster on F-order
col_sums_f = matrix_f.sum(axis=0)  # Fast
In-place Operations
# Avoid creating intermediate arrays
arr = np.random.randn(1000000)
# Bad - creates temporary arrays
result = (arr + 1) * 2 - 3
# Better - use in-place operations
arr += 1
arr *= 2
arr -= 3
# Use out parameter
arr = np.random.randn(1000)
result = np.empty_like(arr)
np.add(arr, 1, out=result)
np.multiply(result, 2, out=result)
np.subtract(result, 3, out=result)
# Compound operations
np.add(arr, 1, out=arr) # Reuse input array
Avoiding Copies
# Slicing creates views (usually)
arr = np.arange(100)
view = arr[10:20] # No copy
# Advanced indexing creates copies
copy = arr[[1, 5, 10]] # Copy created
# Reshaping returns view if possible
view = arr.reshape(10, 10)  # View
copy = arr.reshape(10, 10).T.reshape(-1)  # Copy (source not contiguous in the requested order)
# Check if operation creates copy
original = np.arange(12)
reshaped = original.reshape(3, 4)
reshaped.base is original # True - it's a view
# Explicit copy when needed
independent = arr.copy()
Vectorization for Speed
# Profile your code
import time
# Slow - Python loop
arr = np.random.randn(1000000)
start = time.time()
result = np.zeros_like(arr)
for i in range(len(arr)):
    result[i] = arr[i] ** 2 if arr[i] > 0 else 0
print(f"Loop: {time.time() - start:.4f}s")
# Fast - vectorized
start = time.time()
result = np.where(arr > 0, arr ** 2, 0)
print(f"Vectorized: {time.time() - start:.4f}s")
# Typically 50-100x faster
# Use specialized functions
# Bad
result = np.sqrt(np.sum(arr ** 2))
# Good
result = np.linalg.norm(arr) # Optimized implementation
Memory-efficient Operations
# Generator function for processing large data in batches
def process_batches(data, batch_size):
    n_batches = len(data) // batch_size
    for i in range(n_batches):
        yield data[i*batch_size:(i+1)*batch_size]
# Memory-mapped arrays for huge datasets
mmap = np.memmap('large_file.dat', dtype='float32', mode='r', shape=(1000000, 1000))
# Only loads data into memory when accessed
# Delete intermediate results
large_array = np.random.randn(10000, 10000)
result = np.sum(large_array, axis=0)
del large_array # Free memory
# Use smaller dtypes when possible
arr_64 = np.random.randn(1000000).astype(np.float64) # 8 MB
arr_32 = np.random.randn(1000000).astype(np.float32) # 4 MB
arr_16 = np.random.randn(1000000).astype(np.float16) # 2 MB
Numba for Ultimate Speed
from numba import jit, prange
# Accelerate with JIT compilation
@jit(nopython=True)
def compute_pairwise_distances(X):
    n = X.shape[0]
    distances = np.zeros((n, n))
    for i in range(n):
        for j in range(i+1, n):
            d = 0.0
            for k in range(X.shape[1]):
                d += (X[i, k] - X[j, k]) ** 2
            distances[i, j] = np.sqrt(d)
            distances[j, i] = distances[i, j]
    return distances

# Parallel execution
@jit(nopython=True, parallel=True)
def parallel_sum_squares(arr):
    result = 0.0
    for i in prange(len(arr)):
        result += arr[i] ** 2
    return result
Common ML Patterns
One-Hot Encoding
# Method 1: Using np.eye
labels = np.array([0, 2, 1, 0, 3])
n_classes = 4
one_hot = np.eye(n_classes)[labels]
# [[1, 0, 0, 0],
# [0, 0, 1, 0],
# [0, 1, 0, 0],
# [1, 0, 0, 0],
# [0, 0, 0, 1]]
# Method 2: Manual
def one_hot_encode(labels, n_classes):
    one_hot = np.zeros((len(labels), n_classes))
    one_hot[np.arange(len(labels)), labels] = 1
    return one_hot
# Reverse: one-hot to labels
labels_recovered = np.argmax(one_hot, axis=1)
Train-Test Split
def train_test_split(X, y, test_size=0.2, random_state=None):
    rng = np.random.default_rng(random_state)
    n = len(X)
    indices = rng.permutation(n)
    split_idx = int(n * (1 - test_size))
    train_idx, test_idx = indices[:split_idx], indices[split_idx:]
    return X[train_idx], X[test_idx], y[train_idx], y[test_idx]

# K-fold cross-validation indices
def k_fold_indices(n, k=5, shuffle=True, random_state=None):
    indices = np.arange(n)
    if shuffle:
        rng = np.random.default_rng(random_state)
        rng.shuffle(indices)
    fold_size = n // k
    for i in range(k):
        test_idx = indices[i*fold_size:(i+1)*fold_size]
        train_idx = np.concatenate([indices[:i*fold_size],
                                    indices[(i+1)*fold_size:]])
        yield train_idx, test_idx
Mini-batch Generation
def generate_batches(X, y, batch_size, shuffle=True, random_state=None):
    """Generator for mini-batches"""
    n = len(X)
    rng = np.random.default_rng(random_state)
    if shuffle:
        indices = rng.permutation(n)
        X, y = X[indices], y[indices]
    n_batches = n // batch_size
    for i in range(n_batches):
        start = i * batch_size
        end = start + batch_size
        yield X[start:end], y[start:end]
    # Last batch (if incomplete)
    if n % batch_size != 0:
        yield X[n_batches*batch_size:], y[n_batches*batch_size:]

# Usage
for X_batch, y_batch in generate_batches(X_train, y_train, batch_size=32):
    # Train on batch
    pass
Distance Computations
# Euclidean distance matrix (vectorized)
def euclidean_distances(X, Y=None):
    """
    Compute pairwise Euclidean distances.
    If Y is None, compute distances within X.
    """
    if Y is None:
        Y = X
    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2*x·y
    X_norm = np.sum(X ** 2, axis=1, keepdims=True)    # (n, 1)
    Y_norm = np.sum(Y ** 2, axis=1, keepdims=True).T  # (1, m)
    distances = X_norm + Y_norm - 2 * X @ Y.T
    # Handle numerical errors
    distances = np.maximum(distances, 0)
    return np.sqrt(distances)

# Cosine similarity
def cosine_similarity(X, Y=None):
    if Y is None:
        Y = X
    X_norm = X / np.linalg.norm(X, axis=1, keepdims=True)
    Y_norm = Y / np.linalg.norm(Y, axis=1, keepdims=True)
    return X_norm @ Y_norm.T

# Manhattan distance
def manhattan_distances(X, Y=None):
    if Y is None:
        Y = X
    return np.sum(np.abs(X[:, np.newaxis] - Y[np.newaxis, :]), axis=2)
Activation Functions
# ReLU
def relu(x):
    return np.maximum(0, x)

def relu_derivative(x):
    return (x > 0).astype(float)

# Leaky ReLU
def leaky_relu(x, alpha=0.01):
    return np.where(x > 0, x, alpha * x)

# Sigmoid
def sigmoid(x):
    # Numerically stable
    return np.where(x >= 0,
                    1 / (1 + np.exp(-x)),
                    np.exp(x) / (1 + np.exp(x)))

def sigmoid_derivative(x):
    s = sigmoid(x)
    return s * (1 - s)

# Tanh
def tanh(x):
    return np.tanh(x)

def tanh_derivative(x):
    return 1 - np.tanh(x) ** 2

# Softmax (numerically stable)
def softmax(x, axis=-1):
    x_max = np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(x - x_max)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
Loss Functions
# Mean Squared Error
def mse(y_true, y_pred):
    return np.mean((y_true - y_pred) ** 2)

def mse_derivative(y_true, y_pred):
    return 2 * (y_pred - y_true) / len(y_true)

# Cross-entropy (numerically stable)
def cross_entropy(y_true, y_pred, epsilon=1e-15):
    # Clip predictions to prevent log(0)
    y_pred = np.clip(y_pred, epsilon, 1 - epsilon)
    return -np.mean(y_true * np.log(y_pred))

# Binary cross-entropy
def binary_cross_entropy(y_true, y_pred, epsilon=1e-15):
    y_pred = np.clip(y_pred, epsilon, 1 - epsilon)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

# Categorical cross-entropy
def categorical_cross_entropy(y_true, y_pred, epsilon=1e-15):
    y_pred = np.clip(y_pred, epsilon, 1 - epsilon)
    return -np.sum(y_true * np.log(y_pred)) / len(y_true)

# Hinge loss (SVM)
def hinge_loss(y_true, y_pred):
    return np.mean(np.maximum(0, 1 - y_true * y_pred))
Convolution Operations
# 1D convolution (simple implementation)
def conv1d(x, kernel, stride=1, padding=0):
    if padding > 0:
        x = np.pad(x, padding, mode='constant')
    n = len(x)
    k = len(kernel)
    output_size = (n - k) // stride + 1
    output = np.zeros(output_size)
    for i in range(output_size):
        start = i * stride
        output[i] = np.sum(x[start:start+k] * kernel)
    return output

# 2D convolution (simple, unoptimized)
def conv2d(image, kernel, stride=1, padding=0):
    if padding > 0:
        image = np.pad(image, padding, mode='constant')
    h, w = image.shape
    kh, kw = kernel.shape
    out_h = (h - kh) // stride + 1
    out_w = (w - kw) // stride + 1
    output = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            r, c = i * stride, j * stride
            output[i, j] = np.sum(image[r:r+kh, c:c+kw] * kernel)
    return output

# Pooling operations
def max_pool2d(x, pool_size=2, stride=2):
    h, w = x.shape
    out_h = (h - pool_size) // stride + 1
    out_w = (w - pool_size) // stride + 1
    output = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            r, c = i * stride, j * stride
            output[i, j] = np.max(x[r:r+pool_size, c:c+pool_size])
    return output

def avg_pool2d(x, pool_size=2, stride=2):
    h, w = x.shape
    out_h = (h - pool_size) // stride + 1
    out_w = (w - pool_size) // stride + 1
    output = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            r, c = i * stride, j * stride
            output[i, j] = np.mean(x[r:r+pool_size, c:c+pool_size])
    return output
Gradient Checking
def numerical_gradient(f, x, epsilon=1e-5):
    """
    Compute numerical gradient using finite differences.
    Useful for debugging backpropagation.
    """
    grad = np.zeros_like(x)
    it = np.nditer(x, flags=['multi_index'], op_flags=['readwrite'])
    while not it.finished:
        idx = it.multi_index
        old_value = x[idx]
        x[idx] = old_value + epsilon
        fx_plus = f(x)
        x[idx] = old_value - epsilon
        fx_minus = f(x)
        grad[idx] = (fx_plus - fx_minus) / (2 * epsilon)
        x[idx] = old_value
        it.iternext()
    return grad

def gradient_check(f, x, analytic_grad, epsilon=1e-5):
    """Check if analytic gradient is correct"""
    numerical_grad = numerical_gradient(f, x, epsilon)
    # Relative error
    numerator = np.linalg.norm(numerical_grad - analytic_grad)
    denominator = np.linalg.norm(numerical_grad) + np.linalg.norm(analytic_grad)
    rel_error = numerator / (denominator + 1e-8)
    print(f"Relative error: {rel_error}")
    return rel_error < 1e-5  # Threshold for "correct"
Data Augmentation Helpers
# Image augmentation primitives
def random_flip(image, horizontal=True, p=0.5, rng=None):
    if rng is None:
        rng = np.random.default_rng()
    if rng.random() < p:
        axis = 1 if horizontal else 0
        return np.flip(image, axis=axis)
    return image

def random_rotation_90(image, p=0.5, rng=None):
    if rng is None:
        rng = np.random.default_rng()
    if rng.random() < p:
        k = rng.integers(1, 4)  # 90, 180, or 270 degrees
        return np.rot90(image, k=k)
    return image

def random_crop(image, crop_size, rng=None):
    if rng is None:
        rng = np.random.default_rng()
    h, w = image.shape[:2]
    ch, cw = crop_size
    top = rng.integers(0, h - ch + 1)
    left = rng.integers(0, w - cw + 1)
    return image[top:top+ch, left:left+cw]

def normalize_image(image, mean, std):
    """Normalize image with mean and std per channel"""
    return (image - mean) / std
Summary
NumPy mastery is essential for ML engineering. Key takeaways:
- Vectorization is king: Avoid Python loops, use array operations
- Broadcasting enables elegance: Learn the rules, use them everywhere
- Memory matters: Understand views vs copies, use appropriate dtypes
- Use the right tool: einsum for complex operations, specialized functions when available
- Profile your code: Measure before optimizing
- Build on NumPy conventions: Your code will integrate better with the ecosystem
Next Steps:
- Practice implementing ML algorithms from scratch in NumPy
- Study PyTorch/TensorFlow source code to see NumPy patterns at scale
- Profile your code to identify bottlenecks
- Learn Numba for the last 10x speedup when vectorization isn't enough
Quantization
Overview
Quantization is the process of reducing the precision of numerical representations in neural networks, typically converting high-precision floating-point weights and activations to lower-precision formats like integers. This technique is fundamental for deploying machine learning models efficiently on resource-constrained devices and achieving faster inference with minimal accuracy loss.
In modern deep learning, quantization has become essential for:
- Deploying large language models (LLMs) on consumer hardware
- Running neural networks on edge devices (smartphones, IoT)
- Reducing inference costs in production systems
- Enabling real-time applications with strict latency requirements
Fundamentals
Numerical Representations
Neural networks traditionally use floating-point arithmetic:
| Format | Bits | Sign | Exponent | Mantissa | Range | Precision |
|---|---|---|---|---|---|---|
| FP32 | 32 | 1 | 8 | 23 | ±3.4×10³⁸ | ~7 decimal digits |
| FP16 | 16 | 1 | 5 | 10 | ±65,504 | ~3 decimal digits |
| BF16 | 16 | 1 | 8 | 7 | ±3.4×10³⁸ | ~2 decimal digits |
| INT8 | 8 | 1 | - | 7 | -128 to 127 | Discrete |
| INT4 | 4 | 1 | - | 3 | -8 to 7 | Discrete |
Brain Float 16 (BF16): Maintains FP32's range with reduced precision, ideal for training.
Integer Formats: Fixed-point arithmetic, faster on specialized hardware.
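The range/precision trade-offs in the table can be observed directly with NumPy (which has FP32 and FP16 but no native BF16, so BF16 is omitted in this sketch):

```python
import numpy as np

# Range: FP16 overflows far earlier than FP32
print(np.finfo(np.float32).max)   # ~3.4e38
print(np.finfo(np.float16).max)   # 65504.0
print(np.float16(1e5))            # exceeds FP16 range -> inf

# Precision: FP16 cannot represent 0.1 to 7 decimal digits
print(float(np.float16(0.1)))     # stored as ~0.0999755859375
```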
Quantization Mathematics
The core quantization operation maps continuous values to discrete levels:
Quantization: q = round(x / scale) + zero_point
Dequantization: x_approx = (q - zero_point) * scale
Parameters:
- scale: Scaling factor determining the step size
- zero_point: Offset for asymmetric quantization
- q: Quantized integer value
- x: Original floating-point value
Symmetric Quantization
Zero-point is 0, simplifying computation:
scale = max(|x_max|, |x_min|) / (2^(b-1) - 1)
q = round(x / scale)
For INT8: scale = max(|x_max|, |x_min|) / 127
Example:
import numpy as np
def symmetric_quantize(x, num_bits=8):
"""Symmetric quantization"""
qmax = 2**(num_bits - 1) - 1 # 127 for INT8
scale = np.max(np.abs(x)) / qmax
q = np.round(x / scale).astype(np.int8)
return q, scale
# Example
x = np.array([1.5, -2.3, 0.5, 3.1])
q, scale = symmetric_quantize(x)
print(f"Original: {x}")
print(f"Quantized: {q}")
print(f"Scale: {scale}")
# Dequantize
x_dequant = q * scale
print(f"Dequantized: {x_dequant}")
print(f"Error: {np.abs(x - x_dequant)}")
Asymmetric Quantization
Uses both scale and zero-point for full range utilization:
scale = (x_max - x_min) / (2^b - 1)
zero_point = round(-x_min / scale)
q = round(x / scale) + zero_point
For UINT8: Full range [0, 255] is utilized.
Example:
def asymmetric_quantize(x, num_bits=8):
"""Asymmetric quantization"""
qmin = 0
qmax = 2**num_bits - 1 # 255 for UINT8
x_min, x_max = x.min(), x.max()
scale = (x_max - x_min) / (qmax - qmin)
zero_point = qmin - round(x_min / scale)
q = np.round(x / scale + zero_point)
q = np.clip(q, qmin, qmax).astype(np.uint8)
return q, scale, zero_point
# Example with positive-only activations (ReLU output)
x = np.array([0.2, 1.5, 0.8, 3.1])
q, scale, zp = asymmetric_quantize(x)
print(f"Original: {x}")
print(f"Quantized: {q}")
print(f"Scale: {scale}, Zero-point: {zp}")
# Dequantize
x_dequant = (q - zp) * scale
print(f"Dequantized: {x_dequant}")
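Putting the two schemes side by side makes the difference concrete: on positive-only data (e.g. ReLU outputs), symmetric INT8 wastes the negative half of its range, so its quantization step is roughly twice as large. A self-contained comparison sketch:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(0.0, 4.0, size=1000)  # positive-only values

# Symmetric: maps [-absmax, absmax] onto [-127, 127]
s_scale = np.abs(x).max() / 127
x_sym = np.round(x / s_scale) * s_scale

# Asymmetric: scale + zero-point map [min, max] onto [0, 255]
a_scale = (x.max() - x.min()) / 255
zp = np.round(-x.min() / a_scale)
x_asym = (np.clip(np.round(x / a_scale) + zp, 0, 255) - zp) * a_scale

print("symmetric  MAE:", np.abs(x - x_sym).mean())
print("asymmetric MAE:", np.abs(x - x_asym).mean())  # roughly half the error
```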
Why Quantization?
Model Size Reduction
Quantization directly reduces model size by using fewer bits per parameter:
| Precision | Memory per Parameter | 7B Model Size | Reduction |
|---|---|---|---|
| FP32 | 4 bytes | 28 GB | Baseline |
| FP16 | 2 bytes | 14 GB | 2× |
| INT8 | 1 byte | 7 GB | 4× |
| INT4 | 0.5 bytes | 3.5 GB | 8× |
Example: LLaMA-7B model:
- FP32: ~28 GB (unusable on consumer GPUs)
- INT8: ~7 GB (fits on RTX 3090)
- INT4: ~3.5 GB (runs on MacBook Pro)
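The table's arithmetic is simply parameters × bytes per parameter; a one-liner reproduces it:

```python
# Model size = number of parameters x bits per parameter / 8 bits-per-byte
def model_size_gb(num_params, bits_per_param):
    return num_params * bits_per_param / 8 / 1e9

for name, bits in [("FP32", 32), ("FP16", 16), ("INT8", 8), ("INT4", 4)]:
    print(f"{name}: {model_size_gb(7e9, bits):.1f} GB")  # 28.0, 14.0, 7.0, 3.5
```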
Inference Speed Improvement
Integer operations are significantly faster than floating-point:
| Operation | NVIDIA A100 Throughput | Speedup |
|---|---|---|
| FP32 | 19.5 TFLOPS | 1× |
| FP16 (Tensor Core) | 312 TFLOPS | 16× |
| INT8 (Tensor Core) | 624 TOPS | 32× |
Memory Bandwidth: Moving data is often the bottleneck
- INT8 requires 4× less memory bandwidth than FP32
- Critical for large models, whose inference is memory-bound rather than compute-bound
Energy Efficiency
Lower precision = lower energy consumption:
| Operation | Energy (pJ) | Relative |
|---|---|---|
| INT8 ADD | 0.03 | 1× |
| FP16 ADD | 0.4 | 13× |
| FP32 ADD | 0.9 | 30× |
| FP32 MULT | 3.7 | 123× |
Essential for:
- Mobile devices (battery life)
- Edge computing (power constraints)
- Data centers (operational costs)
Edge Deployment
Many edge devices only support integer operations:
- ARM Cortex-M processors
- Google Edge TPU
- Qualcomm Hexagon DSP
- Apple Neural Engine
Quantization enables running sophisticated models on these devices.
Types of Quantization
Post-Training Quantization (PTQ)
Quantize a pre-trained model without retraining. Fast but may lose accuracy.
Dynamic Quantization
Quantizes weights statically, activations dynamically at runtime.
Characteristics:
- Weights: Quantized and stored as INT8
- Activations: Quantized on-the-fly during inference
- No calibration data needed
- Best for memory-bound models (LSTMs, Transformers)
PyTorch Example:
import torch
import torch.quantization
# Original model
model = MyTransformer()
model.eval()
# Dynamic quantization
quantized_model = torch.quantization.quantize_dynamic(
model,
{torch.nn.Linear, torch.nn.LSTM}, # Layers to quantize
dtype=torch.qint8
)
# Inference
with torch.no_grad():
output = quantized_model(input_tensor)
# Check size reduction on disk (dynamic quantization packs weights inside
# the replaced modules, so summing .parameters() would under-count the savings)
import os
torch.save(model.state_dict(), "fp32.pth")
torch.save(quantized_model.state_dict(), "int8.pth")
print(f"Size reduction: {os.path.getsize('fp32.pth') / os.path.getsize('int8.pth'):.2f}×")
os.remove("fp32.pth"); os.remove("int8.pth")
When to use:
- Quick deployment with minimal accuracy loss
- LSTM/Transformer models
- When activation distribution changes per input
Static Quantization
Quantizes both weights and activations using calibration data.
Characteristics:
- Weights: Pre-quantized to INT8
- Activations: Pre-computed scale/zero-point from calibration
- Requires representative calibration dataset
- Best for convolutional networks
- Maximum performance gain
PyTorch Example:
import torch
import torch.quantization
# Prepare model for quantization
model = MyConvNet()
model.eval()
# Specify quantization configuration
model.qconfig = torch.quantization.get_default_qconfig('fbgemm') # x86 CPUs
# Fuse operations (Conv + BatchNorm + ReLU)
torch.quantization.fuse_modules(model, [['conv', 'bn', 'relu']], inplace=True)
# Prepare for static quantization
torch.quantization.prepare(model, inplace=True)
# Calibration: Run representative data through model
with torch.no_grad():
for batch in calibration_data_loader:
model(batch)
# Convert to quantized model
torch.quantization.convert(model, inplace=True)
# Save quantized model
torch.save(model.state_dict(), 'quantized_model.pth')
# Inference
with torch.no_grad():
output = model(input_tensor)
Calibration Best Practices:
def calibrate_model(model, data_loader, num_batches=100):
"""
Calibrate quantization parameters
"""
model.eval()
with torch.no_grad():
for i, (images, _) in enumerate(data_loader):
if i >= num_batches:
break
model(images)
return model
# Use diverse calibration data
# 100-1000 samples usually sufficient
calibrated_model = calibrate_model(prepared_model, val_loader, num_batches=200)
Quantization-Aware Training (QAT)
Simulates quantization during training to maintain accuracy.
Characteristics:
- Fake quantization in forward pass
- Full precision gradients in backward pass
- Highest accuracy for aggressive quantization
- Requires training time and data
How it works:
- Forward pass: Apply quantization (fake quant nodes)
- Compute loss with quantized values
- Backward pass: Use straight-through estimators
- Update weights in full precision
PyTorch Example:
import torch
import torch.quantization
# Start with pre-trained model
model = MyModel()
model.load_state_dict(torch.load('pretrained.pth'))
# Set to training mode
model.train()
# Configure QAT
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
# Prepare for QAT
torch.quantization.prepare_qat(model, inplace=True)
# Fine-tune with quantization simulation
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
criterion = torch.nn.CrossEntropyLoss()
num_epochs = 5 # Fine-tuning epochs
for epoch in range(num_epochs):
for inputs, labels in train_loader:
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
print(f"Epoch {epoch}: Loss = {loss.item():.4f}")
# Convert to fully quantized model
model.eval()
torch.quantization.convert(model, inplace=True)
# Evaluate
accuracy = evaluate(model, test_loader)
print(f"Quantized model accuracy: {accuracy:.2f}%")
Fake Quantization:
class FakeQuantize(torch.nn.Module):
"""Simulates quantization effects during training"""
def __init__(self, num_bits=8):
super().__init__()
self.num_bits = num_bits
self.qmin = 0
self.qmax = 2**num_bits - 1
self.scale = torch.nn.Parameter(torch.ones(1))
self.zero_point = torch.nn.Parameter(torch.zeros(1))
def forward(self, x):
# Quantize
q = torch.clamp(
torch.round(x / self.scale + self.zero_point),
self.qmin, self.qmax
)
# Dequantize
x_fake_quant = (q - self.zero_point) * self.scale
return x_fake_quant
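One detail the module above glosses over: `round()` has zero gradient almost everywhere, so training through it directly would stall. QAT uses the straight-through estimator (STE) mentioned earlier, which quantizes in the forward pass but lets gradients pass through as if the op were the identity. A common one-line trick (a sketch, not PyTorch's internal implementation):

```python
import torch

def fake_quant_ste(x, scale, zero_point, qmin=0, qmax=255):
    q = torch.clamp(torch.round(x / scale + zero_point), qmin, qmax)
    x_q = (q - zero_point) * scale
    # Forward value is x_q, but the graph sees x plus a constant,
    # so dL/dx flows through unchanged (straight-through estimator)
    return x + (x_q - x).detach()

x = torch.randn(5, requires_grad=True)
y = fake_quant_ste(x, scale=torch.tensor(0.1), zero_point=torch.tensor(128.0))
y.sum().backward()
print(x.grad)  # all ones: gradient of the identity
```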
Quantization Granularity
Per-Tensor Quantization
Single scale/zero-point for entire tensor.
Advantages:
- Simpler implementation
- Faster computation
- Lower memory overhead
Disadvantages:
- Less accurate for tensors with wide value ranges
- Outliers affect entire tensor
def per_tensor_quantize(tensor, num_bits=8):
"""Quantize entire tensor with single scale"""
qmin, qmax = 0, 2**num_bits - 1
min_val, max_val = tensor.min(), tensor.max()
scale = (max_val - min_val) / (qmax - qmin)
zero_point = qmin - torch.round(min_val / scale)
q = torch.clamp(
torch.round(tensor / scale + zero_point),
qmin, qmax
)
return q, scale, zero_point
Per-Channel Quantization
Different scale/zero-point per output channel.
Advantages:
- Higher accuracy, especially for convolutional layers
- Handles per-channel variance better
Disadvantages:
- More complex
- Requires hardware support
Applied to: Weights (not activations, due to hardware constraints)
def per_channel_quantize(weight, num_bits=8):
"""
Quantize per output channel (conv filters)
weight shape: [out_channels, in_channels, kernel_h, kernel_w]
"""
out_channels = weight.shape[0]
qmin, qmax = -(2**(num_bits-1)), 2**(num_bits-1) - 1
scales = []
zero_points = []
q_weight = torch.zeros_like(weight, dtype=torch.int8)
for ch in range(out_channels):
ch_weight = weight[ch]
ch_min, ch_max = ch_weight.min(), ch_weight.max()
# Symmetric quantization per channel
scale = max(abs(ch_min), abs(ch_max)) / qmax
scales.append(scale)
zero_points.append(0)
q_weight[ch] = torch.clamp(
torch.round(ch_weight / scale),
qmin, qmax
).to(torch.int8)
return q_weight, torch.tensor(scales), torch.tensor(zero_points)
# Example
conv_weight = torch.randn(64, 3, 3, 3) # 64 filters
q_weight, scales, zps = per_channel_quantize(conv_weight)
print(f"Original shape: {conv_weight.shape}")
print(f"Quantized shape: {q_weight.shape}")
print(f"Scales per channel: {scales.shape}")
Group Quantization
Quantize groups of channels together (compromise between per-tensor and per-channel).
def group_quantize(weight, group_size=4, num_bits=4):
"""Group quantization for weights"""
out_channels = weight.shape[0]
num_groups = (out_channels + group_size - 1) // group_size
scales = []
q_weight = torch.zeros_like(weight, dtype=torch.int8)
for g in range(num_groups):
start = g * group_size
end = min(start + group_size, out_channels)
group_weight = weight[start:end]
scale = group_weight.abs().max() / (2**(num_bits-1) - 1)
scales.append(scale)
q_weight[start:end] = torch.round(group_weight / scale)
return q_weight, torch.tensor(scales)
Advanced Quantization Techniques
Mixed Precision Quantization
Use different precision for different layers based on sensitivity.
Strategy:
- Profile layer sensitivity to quantization
- Keep sensitive layers in higher precision
- Aggressively quantize insensitive layers
def quantize_mixed_precision(model, sensitivity_dict):
"""
Apply different quantization based on layer sensitivity
sensitivity_dict: {layer_name: num_bits}
"""
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
if name in sensitivity_dict:
bits = sensitivity_dict[name]
if bits == 8:
# Standard INT8 quantization
quantize_layer(module, num_bits=8)
elif bits == 4:
# Aggressive INT4 quantization
quantize_layer(module, num_bits=4)
else:
# Keep in FP16
module.half()
# Example sensitivity analysis
def analyze_sensitivity(model, data_loader):
"""Measure accuracy drop per layer"""
baseline_acc = evaluate(model, data_loader)
sensitivity = {}
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
# Temporarily quantize this layer
original_weight = module.weight.data.clone()
module.weight.data = quantize_dequantize(original_weight, num_bits=8)
acc = evaluate(model, data_loader)
sensitivity[name] = baseline_acc - acc
# Restore
module.weight.data = original_weight
return sensitivity
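The sensitivity analysis above relies on a `quantize_dequantize` helper that is not defined here. A minimal per-tensor version (an assumption on my part, matching the symmetric scheme used earlier) could look like:

```python
import torch

def quantize_dequantize(w, num_bits=8):
    """Symmetric round-trip: quantize then immediately dequantize,
    so the tensor stays float but carries the quantization error."""
    qmax = 2 ** (num_bits - 1) - 1
    scale = w.abs().max() / qmax
    return torch.clamp(torch.round(w / scale), -qmax, qmax) * scale

w = torch.randn(64, 64)
w_q = quantize_dequantize(w)
print((w - w_q).abs().max())  # bounded by roughly scale / 2
```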
GPTQ (GPT Quantization)
Advanced post-training quantization for large language models using layer-wise quantization with Hessian information.
Key Idea: Minimize reconstruction error layer-by-layer using second-order information.
# Using auto-gptq library
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
# Configure GPTQ
quantize_config = BaseQuantizeConfig(
bits=4, # INT4 quantization
group_size=128, # Group size for quantization
desc_act=False, # Activation order
)
# Load model
model = AutoGPTQForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
quantize_config=quantize_config
)
# Prepare calibration data
from datasets import load_dataset
calibration_data = load_dataset("c4", split="train[:1000]")
def prepare_calibration(examples):
return tokenizer(examples["text"], truncation=True, max_length=512)
calibration_dataset = calibration_data.map(prepare_calibration)
# Quantize
model.quantize(calibration_dataset)
# Save quantized model
model.save_quantized("./llama-7b-gptq-4bit")
# Load and use
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
quantized_model = AutoGPTQForCausalLM.from_quantized("./llama-7b-gptq-4bit")
# Generate
input_ids = tokenizer("Once upon a time", return_tensors="pt").input_ids
output = quantized_model.generate(input_ids, max_length=100)
print(tokenizer.decode(output[0]))
GPTQ Algorithm:
- Process model layer-by-layer
- For each layer, use Hessian matrix to determine optimal quantization
- Update weights to minimize reconstruction error
- Use Cholesky decomposition for efficient computation
AWQ (Activation-aware Weight Quantization)
Protects weights corresponding to important activations.
Key Insight: Not all weights are equally important. Weights that multiply with large activations are more critical.
# Using AutoAWQ library
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
# Load model
model = AutoAWQForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# Quantize
quant_config = {
"zero_point": True,
"q_group_size": 128,
"w_bit": 4,
"version": "GEMM"
}
model.quantize(
tokenizer,
quant_config=quant_config,
calib_data="pileval" # Calibration dataset
)
# Save
model.save_quantized("./llama-7b-awq-4bit")
# Load and inference
from awq import AutoAWQForCausalLM
model = AutoAWQForCausalLM.from_quantized("./llama-7b-awq-4bit", fuse_layers=True)
AWQ Method:
- Observe activation distributions
- Scale weights based on activation magnitudes
- Quantize scaled weights
- Adjust scales to maintain equivalence
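The four steps can be sketched numerically. This is a toy per-channel version (not the actual AWQ kernels): scaling a weight row up by s before quantization shrinks its relative quantization error, and dividing the matching activation column by s keeps the product mathematically unchanged.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(32, 8))          # activations [tokens, in_features]
W = rng.normal(size=(8, 8))           # weights     [in_features, out_features]

def quant_dequant(w, bits=4):
    qmax = 2 ** (bits - 1) - 1
    scale = np.abs(w).max() / qmax
    return np.clip(np.round(w / scale), -qmax, qmax) * scale

# Steps 1-2: per-input-channel scales from activation magnitudes (alpha = 0.5)
s = np.abs(X).max(axis=0) ** 0.5
# Step 3: quantize the scaled weights
W_q = quant_dequant(W * s[:, None])
# Step 4: fold the inverse scale into the activations to preserve equivalence
Y_awq = (X / s) @ W_q
print(np.abs(X @ W - Y_awq).mean())   # quantization error of the scaled scheme
```

Without the quantization step, `(X / s) @ (W * s[:, None])` equals `X @ W` exactly, which is the "maintain equivalence" property the method depends on.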
SmoothQuant
Migrates quantization difficulty from activations to weights.
Problem: Activations often have larger outliers than weights, making them harder to quantize.
Solution: Apply mathematically equivalent transformations to smooth activations.
def smooth_quant(weight, activation, alpha=0.5):
"""
SmoothQuant transformation
    Y = (X · diag(s)^(-1)) · (diag(s) · W) = X · W
    where s_j = max(|X_j|)^alpha / max(|W_j|)^(1-alpha)
"""
# Calculate smoothing scales
activation_absmax = activation.abs().max(dim=0).values
weight_absmax = weight.abs().max(dim=0).values
scales = (activation_absmax ** alpha) / (weight_absmax ** (1 - alpha))
# Apply smoothing
smoothed_weight = weight * scales.unsqueeze(0)
smoothed_activation = activation / scales.unsqueeze(0)
return smoothed_weight, smoothed_activation, scales
# Integration with quantization
class SmoothQuantLinear(torch.nn.Module):
    def __init__(self, linear_layer, alpha=0.5):
        super().__init__()
        self.alpha = alpha
        self.weight = linear_layer.weight  # keep a reference to the FP weights
        self.scales = None
        self.quantized_weight = None
def calibrate(self, activations):
"""Calibrate smoothing scales"""
self.scales = calculate_smooth_scales(
self.weight, activations, self.alpha
)
smoothed_weight = self.weight * self.scales
self.quantized_weight = quantize(smoothed_weight)
def forward(self, x):
smoothed_x = x / self.scales
return F.linear(smoothed_x, self.quantized_weight)
LLM.int8()
Decomposes matrix multiplication into INT8 and FP16 components.
Key Idea: Most values can be quantized to INT8, but rare outliers are kept in FP16.
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
# Configure LLM.int8()
quantization_config = BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0, # Outlier threshold
llm_int8_has_fp16_weight=False
)
# Load model with INT8 quantization
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
quantization_config=quantization_config,
device_map="auto"
)
# Model automatically uses INT8 for most operations
# Outliers are processed in FP16
output = model.generate(input_ids, max_length=100)
How it works:
- Identify outlier features (magnitude > threshold)
- Separate into two matrix multiplications:
- Regular features: INT8 × INT8
- Outlier features: FP16 × FP16
- Combine results
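The decomposition can be illustrated with a toy NumPy version (purely illustrative, not the bitsandbytes kernels): one INT8 matmul for regular feature columns, one float matmul for the rare outlier columns, then sum.

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(16, 64))
X[:, 3] *= 50                        # inject an outlier feature dimension
W = rng.normal(size=(64, 32))

threshold = 6.0
outlier = np.abs(X).max(axis=0) > threshold   # per-feature outlier mask

def int8_matmul(A, B):
    """Symmetric per-tensor INT8 quantize, integer matmul, rescale."""
    sa = np.abs(A).max() / 127
    sb = np.abs(B).max() / 127
    Aq = np.round(A / sa).astype(np.int32)
    Bq = np.round(B / sb).astype(np.int32)
    return (Aq @ Bq) * sa * sb

# Regular features in INT8, outlier features in float, results combined
Y = int8_matmul(X[:, ~outlier], W[~outlier]) + X[:, outlier] @ W[outlier]
print(np.abs(Y - X @ W).mean())      # far smaller than quantizing everything
```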
4-bit Quantization with NormalFloat (NF4)
Introduced in QLoRA, optimized for normally distributed weights.
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch
# Configure 4-bit quantization
nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4", # NormalFloat 4-bit
bnb_4bit_use_double_quant=True, # Double quantization
bnb_4bit_compute_dtype=torch.bfloat16 # Compute in BF16
)
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-13b-hf",
quantization_config=nf4_config,
device_map="auto"
)
# Can even fine-tune in 4-bit with LoRA
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
# Train with 4-bit base model + 16-bit LoRA adapters
trainer.train()
NF4 Quantization Bins: Optimized for Gaussian distributions
# NF4 quantization levels (non-uniform)
NF4_LEVELS = [
-1.0, -0.6961928009986877, -0.5250730514526367,
-0.39491748809814453, -0.28444138169288635,
-0.18477343022823334, -0.09105003625154495,
0.0, 0.07958029955625534, 0.16093020141124725,
0.24611230194568634, 0.33791524171829224,
0.44070982933044434, 0.5626170039176941,
0.7229568362236023, 1.0
]
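A minimal NF4 round-trip using the level table above (a sketch of the idea, not the bitsandbytes block layout): normalize a block by its absmax, snap each value to the nearest of the 16 levels, and rescale on dequantization.

```python
import numpy as np

NF4_LEVELS = np.array([
    -1.0, -0.6961928009986877, -0.5250730514526367,
    -0.39491748809814453, -0.28444138169288635,
    -0.18477343022823334, -0.09105003625154495,
    0.0, 0.07958029955625534, 0.16093020141124725,
    0.24611230194568634, 0.33791524171829224,
    0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0
])

def nf4_quantize(x):
    absmax = np.abs(x).max()
    # Nearest-level lookup on the normalized values
    idx = np.abs(x[:, None] / absmax - NF4_LEVELS[None, :]).argmin(axis=1)
    return idx.astype(np.uint8), absmax   # 16 levels fit in 4 bits

def nf4_dequantize(idx, absmax):
    return NF4_LEVELS[idx] * absmax

x = np.random.default_rng(0).normal(size=256).astype(np.float32)
idx, absmax = nf4_quantize(x)
x_hat = nf4_dequantize(idx, absmax)
print(np.abs(x - x_hat).mean())
```

The non-uniform spacing (dense near zero, sparse near ±1) is what makes NF4 a good fit for the roughly Gaussian weight distributions of trained networks.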
Quantization for Different Architectures
Convolutional Neural Networks (CNNs)
CNNs are relatively robust to quantization due to:
- Spatial redundancy in image data
- Batch normalization stabilization
- ReLU activations (non-negative, easier to quantize)
Best Practices:
def quantize_cnn(model):
"""Quantize CNN model"""
# 1. Fuse operations
torch.quantization.fuse_modules(
model,
[['conv1', 'bn1', 'relu']],
inplace=True
)
# 2. Use per-channel quantization for conv layers
model.qconfig = torch.quantization.QConfig(
activation=torch.quantization.default_observer,
weight=torch.quantization.default_per_channel_weight_observer
)
# 3. First and last layers: keep higher precision or use symmetric
# model.conv1.qconfig = custom_qconfig_fp16
# model.fc.qconfig = custom_qconfig_fp16
return model
# Layer fusion example
model = models.resnet18(pretrained=True)
model.eval()
# Fuse Conv-BN-ReLU
fused_model = torch.quantization.fuse_modules(
model,
[
['conv1', 'bn1', 'relu'],
['layer1.0.conv1', 'layer1.0.bn1', 'layer1.0.relu'],
# ... more layers
]
)
Quantization-friendly Architecture:
class QuantizableMobileNetV2(nn.Module):
"""MobileNetV2 designed for quantization"""
def __init__(self):
super().__init__()
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()
# Use quantization-friendly operations
self.features = nn.Sequential(
# Depthwise separable convolutions
nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(32),
nn.ReLU6(inplace=True),
# ... more layers
)
self.classifier = nn.Linear(1280, num_classes)
def forward(self, x):
x = self.quant(x) # Quantize input
x = self.features(x)
x = self.classifier(x)
x = self.dequant(x) # Dequantize output
return x
Transformers and Large Language Models
Transformers are more sensitive to quantization due to:
- Attention mechanisms with softmax (outliers)
- Layer normalization
- Large embedding tables
- Accumulated errors over many layers
Challenges:
- Outlier features: Some dimensions have extreme values
- Embedding tables: Large memory footprint
- Attention scores: Sensitive to precision
Solutions:
# 1. Layer-wise quantization sensitivity
def quantize_transformer_selective(model):
"""Selectively quantize transformer components"""
for name, module in model.named_modules():
if 'attention' in name:
# Keep attention in higher precision
module.qconfig = get_qconfig_fp16()
elif 'mlp' in name or 'feed_forward' in name:
# Aggressively quantize feed-forward
module.qconfig = get_qconfig_int8()
elif 'layernorm' in name:
# Keep normalization in FP16
module.qconfig = None
# 2. Quantize with outlier handling
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"gpt2",
load_in_8bit=True, # Uses LLM.int8()
device_map="auto",
max_memory={0: "20GB", "cpu": "30GB"}
)
# 3. K-V cache quantization for faster inference
class QuantizedAttention(nn.Module):
"""Attention with quantized K-V cache"""
def __init__(self, config):
super().__init__()
self.config = config
self.kv_bits = 8 # Quantize cached keys/values
def forward(self, hidden_states, past_key_value=None):
# Compute Q, K, V
query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states)
value = self.v_proj(hidden_states)
# Quantize K, V for caching
if self.training:
# During training, use FP
past_key_value = (key, value)
else:
# During inference, quantize K-V cache
key_q, key_scale = quantize_tensor(key, self.kv_bits)
value_q, value_scale = quantize_tensor(value, self.kv_bits)
past_key_value = (key_q, key_scale, value_q, value_scale)
# Attention computation...
return output, past_key_value
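The cache sketch above calls a `quantize_tensor` helper that is not defined in these notes; a minimal per-tensor symmetric version (an assumption, consistent with the schemes used earlier) might be:

```python
import torch

def quantize_tensor(t, num_bits=8):
    """Per-tensor symmetric quantization; returns int8 tensor and its scale."""
    qmax = 2 ** (num_bits - 1) - 1
    scale = t.abs().max().clamp(min=1e-8) / qmax
    q = torch.clamp(torch.round(t / scale), -qmax, qmax).to(torch.int8)
    return q, scale

def dequantize_tensor(q, scale):
    return q.float() * scale

k = torch.randn(2, 4, 64)            # e.g. a cached key tensor
kq, kscale = quantize_tensor(k)
print((k - dequantize_tensor(kq, kscale)).abs().max())  # bounded by ~scale/2
```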
GPTQ for LLMs:
# Comprehensive GPTQ quantization
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
quantize_config = BaseQuantizeConfig(
bits=4,
group_size=128,
damp_percent=0.01,
desc_act=True, # Better accuracy
sym=False, # Asymmetric quantization
true_sequential=True, # Sequential quantization
model_name_or_path=None,
model_file_base_name="model"
)
# Quantize
model.quantize(
examples=calibration_data,
batch_size=1,
use_triton=True, # Faster with Triton kernels
autotune_warmup_after_quantized=True
)
Vision Transformers (ViT)
Combine challenges of both CNNs and Transformers:
def quantize_vit(model, quantize_attention=False):
"""Quantize Vision Transformer"""
for name, module in model.named_modules():
if 'patch_embed' in name:
# Patch embedding: keep higher precision
module.qconfig = get_qconfig_fp16()
elif 'attn' in name and not quantize_attention:
# Attention: conditional quantization
module.qconfig = None
elif 'mlp' in name:
# MLP blocks: aggressive INT8
module.qconfig = get_qconfig_int8()
return model
# PTQ for ViT
def ptq_vision_transformer(model, calibration_loader):
model.eval()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
# Selectively quantize
quantize_vit(model, quantize_attention=False)
# Prepare
torch.quantization.prepare(model, inplace=True)
# Calibrate with image data
with torch.no_grad():
for images, _ in calibration_loader:
model(images)
# Convert
torch.quantization.convert(model, inplace=True)
return model
Recurrent Neural Networks (RNNs/LSTMs)
RNNs benefit significantly from dynamic quantization:
# Dynamic quantization for LSTM
model = nn.LSTM(input_size=256, hidden_size=512, num_layers=2)
quantized_model = torch.quantization.quantize_dynamic(
model,
{nn.LSTM, nn.Linear},
dtype=torch.qint8
)
# For static quantization of RNNs (more complex)
class QuantizableLSTM(nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.lstm = nn.LSTM(input_size, hidden_size)
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()
def forward(self, x, hidden=None):
x = self.quant(x)
output, hidden = self.lstm(x, hidden)
output = self.dequant(output)
return output, hidden
Practical Implementation Examples
Example 1: Quantizing ResNet for Image Classification
import torch
import torchvision.models as models
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import os  # used below for on-disk size comparison
# 1. Load the quantizable pre-trained model (the variants under
#    torchvision.models.quantization provide fuse_model())
model = models.quantization.resnet50(pretrained=True, quantize=False)
model.eval()
# 2. Prepare data
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
calibration_dataset = datasets.ImageFolder('imagenet/val', transform=transform)
calibration_loader = DataLoader(
calibration_dataset,
batch_size=32,
shuffle=True,
num_workers=4
)
# 3. Fuse modules
model.fuse_model() # Fuse Conv-BN-ReLU
# 4. Set quantization config
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
# 5. Prepare for calibration
torch.quantization.prepare(model, inplace=True)
# 6. Calibrate
print("Calibrating...")
num_calibration_batches = 100
with torch.no_grad():
for i, (images, _) in enumerate(calibration_loader):
if i >= num_calibration_batches:
break
model(images)
if (i + 1) % 10 == 0:
print(f"Calibrated {i + 1} batches")
# 7. Convert to quantized model
torch.quantization.convert(model, inplace=True)
# 8. Evaluate
def evaluate(model, data_loader, num_batches=None):
model.eval()
correct = 0
total = 0
with torch.no_grad():
for i, (images, labels) in enumerate(data_loader):
if num_batches and i >= num_batches:
break
outputs = model(images)
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
return 100. * correct / total
print("Evaluating quantized model...")
accuracy = evaluate(model, calibration_loader, num_batches=200)
print(f"Quantized model accuracy: {accuracy:.2f}%")
# 9. Save quantized model
torch.save(model.state_dict(), 'resnet50_quantized.pth')
# 10. Compare model sizes
def print_model_size(model, label):
torch.save(model.state_dict(), "temp.pth")
size_mb = os.path.getsize("temp.pth") / 1e6
print(f"{label}: {size_mb:.2f} MB")
os.remove("temp.pth")
original_model = models.resnet50(pretrained=True)
print_model_size(original_model, "Original FP32")
print_model_size(model, "Quantized INT8")
Example 2: QAT for Custom Model
import torch
import torch.nn as nn
import torch.quantization
class CustomModel(nn.Module):
def __init__(self, num_classes=10):
super().__init__()
self.quant = torch.quantization.QuantStub()
self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
self.bn1 = nn.BatchNorm2d(32)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
self.bn2 = nn.BatchNorm2d(64)
self.relu2 = nn.ReLU()
self.pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(64, num_classes)
self.dequant = torch.quantization.DeQuantStub()
def forward(self, x):
x = self.quant(x)
x = self.relu1(self.bn1(self.conv1(x)))
x = self.relu2(self.bn2(self.conv2(x)))
x = self.pool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
x = self.dequant(x)
return x
def fuse_model(self):
torch.quantization.fuse_modules(
self,
[['conv1', 'bn1', 'relu1'],
['conv2', 'bn2', 'relu2']],
inplace=True
)
# 1. Train FP32 model first
model = CustomModel(num_classes=10)
# ... training code ...
torch.save(model.state_dict(), 'model_fp32.pth')
# 2. Prepare for QAT
model.load_state_dict(torch.load('model_fp32.pth'))
model.train()
# Fuse layers
model.fuse_model()
# Set QAT config
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
# Prepare QAT
torch.quantization.prepare_qat(model, inplace=True)
# 3. Fine-tune with QAT
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)
criterion = nn.CrossEntropyLoss()
num_epochs = 3
for epoch in range(num_epochs):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}')
# Validation
model.eval()
val_acc = evaluate(model, val_loader)
print(f'Epoch {epoch}, Validation Accuracy: {val_acc:.2f}%')
# 4. Convert to fully quantized model
model.eval()
torch.quantization.convert(model, inplace=True)
# 5. Final evaluation
test_acc = evaluate(model, test_loader)
print(f'Quantized model test accuracy: {test_acc:.2f}%')
# 6. Save
torch.save(model.state_dict(), 'model_qat_int8.pth')
Example 3: Quantizing BERT for NLP
from transformers import BertForSequenceClassification, BertTokenizer
import torch
import os
# 1. Load model
model_name = "bert-base-uncased"
model = BertForSequenceClassification.from_pretrained(
model_name,
num_labels=2
)
tokenizer = BertTokenizer.from_pretrained(model_name)
# 2. Dynamic quantization (easiest for transformers)
quantized_model = torch.quantization.quantize_dynamic(
model,
{torch.nn.Linear}, # Quantize linear layers
dtype=torch.qint8
)
# 3. Test inference
text = "This movie was fantastic! I loved every minute of it."
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():
# Original model
output_fp32 = model(**inputs)
# Quantized model
output_int8 = quantized_model(**inputs)
print("FP32 logits:", output_fp32.logits)
print("INT8 logits:", output_int8.logits)
# 4. Compare sizes
def get_model_size(model):
torch.save(model.state_dict(), "temp.pth")
size = os.path.getsize("temp.pth") / 1e6
os.remove("temp.pth")
return size
fp32_size = get_model_size(model)
int8_size = get_model_size(quantized_model)
print(f"FP32 model: {fp32_size:.2f} MB")
print(f"INT8 model: {int8_size:.2f} MB")
print(f"Compression ratio: {fp32_size / int8_size:.2f}×")
# 5. Benchmark inference speed
import time
def benchmark(model, inputs, num_runs=100):
# Warmup
for _ in range(10):
model(**inputs)
start = time.time()
for _ in range(num_runs):
with torch.no_grad():
model(**inputs)
end = time.time()
return (end - start) / num_runs
fp32_time = benchmark(model, inputs)
int8_time = benchmark(quantized_model, inputs)
print(f"FP32 inference: {fp32_time*1000:.2f} ms")
print(f"INT8 inference: {int8_time*1000:.2f} ms")
print(f"Speedup: {fp32_time / int8_time:.2f}×")
Example 4: 4-bit LLM Quantization with bitsandbytes
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
# 1. Configure 4-bit quantization
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True, # Nested quantization
bnb_4bit_quant_type="nf4", # NormalFloat4
bnb_4bit_compute_dtype=torch.bfloat16 # Compute dtype
)
# 2. Load model in 4-bit
model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(
model_id,
quantization_config=bnb_config,
device_map="auto", # Automatically distribute across GPUs
trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# 3. Generate text
prompt = "Explain quantum computing in simple terms:"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=200,
        temperature=0.7,
        top_p=0.9,
        do_sample=True
    )
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
# 4. Memory usage
print(f"Model memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")
# 5. Can even fine-tune with QLoRA
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
# Prepare for k-bit training
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)
# Add LoRA adapters
config = LoraConfig(
r=8,
lora_alpha=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)
model = get_peft_model(model, config)
model.print_trainable_parameters()  # prints trainable vs. total parameter counts
# Now you can fine-tune with standard training loop
# Only LoRA adapters are trained (in FP32/BF16)
# Base model stays in 4-bit
Performance Analysis and Benchmarking
Measuring Quantization Impact
import os
import time

import numpy as np
import torch
from sklearn.metrics import accuracy_score
class QuantizationBenchmark:
    """Comprehensive quantization benchmarking"""

    def __init__(self, model_fp32, model_quantized, test_loader):
        self.model_fp32 = model_fp32
        self.model_quantized = model_quantized
        self.test_loader = test_loader

    def measure_accuracy(self, model, num_batches=None):
        """Measure model accuracy"""
        model.eval()
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for i, (inputs, labels) in enumerate(self.test_loader):
                if num_batches and i >= num_batches:
                    break
                outputs = model(inputs)
                preds = outputs.argmax(dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        return accuracy_score(all_labels, all_preds) * 100

    def measure_latency(self, model, num_runs=100):
        """Measure inference latency"""
        model.eval()
        # Get a sample batch
        sample_input, _ = next(iter(self.test_loader))
        # Warmup
        with torch.no_grad():
            for _ in range(10):
                _ = model(sample_input)
        # Benchmark
        latencies = []
        with torch.no_grad():
            for _ in range(num_runs):
                start = time.perf_counter()
                _ = model(sample_input)
                end = time.perf_counter()
                latencies.append((end - start) * 1000)  # ms
        return {
            'mean': np.mean(latencies),
            'std': np.std(latencies),
            'p50': np.percentile(latencies, 50),
            'p95': np.percentile(latencies, 95),
            'p99': np.percentile(latencies, 99)
        }

    def measure_throughput(self, model, duration=10):
        """Measure throughput (samples/sec)"""
        model.eval()
        sample_input, _ = next(iter(self.test_loader))
        batch_size = sample_input.size(0)
        num_batches = 0
        start = time.time()
        with torch.no_grad():
            while time.time() - start < duration:
                _ = model(sample_input)
                num_batches += 1
        elapsed = time.time() - start
        throughput = (num_batches * batch_size) / elapsed
        return throughput

    def measure_model_size(self, model):
        """Measure model size in MB"""
        torch.save(model.state_dict(), "temp_model.pth")
        size_mb = os.path.getsize("temp_model.pth") / 1e6
        os.remove("temp_model.pth")
        return size_mb

    def run_full_benchmark(self):
        """Run complete benchmark suite"""
        print("=" * 60)
        print("Quantization Benchmark Results")
        print("=" * 60)

        # Accuracy
        print("\n[1] Accuracy")
        fp32_acc = self.measure_accuracy(self.model_fp32)
        quant_acc = self.measure_accuracy(self.model_quantized)
        print(f" FP32: {fp32_acc:.2f}%")
        print(f" Quantized: {quant_acc:.2f}%")
        print(f" Drop: {fp32_acc - quant_acc:.2f}%")

        # Model Size
        print("\n[2] Model Size")
        fp32_size = self.measure_model_size(self.model_fp32)
        quant_size = self.measure_model_size(self.model_quantized)
        print(f" FP32: {fp32_size:.2f} MB")
        print(f" Quantized: {quant_size:.2f} MB")
        print(f" Reduction: {fp32_size / quant_size:.2f}×")

        # Latency
        print("\n[3] Latency (ms)")
        fp32_latency = self.measure_latency(self.model_fp32)
        quant_latency = self.measure_latency(self.model_quantized)
        print(f" FP32: {fp32_latency['mean']:.2f} ± {fp32_latency['std']:.2f}")
        print(f" Quantized: {quant_latency['mean']:.2f} ± {quant_latency['std']:.2f}")
        print(f" Speedup: {fp32_latency['mean'] / quant_latency['mean']:.2f}×")

        # Throughput
        print("\n[4] Throughput (samples/sec)")
        fp32_throughput = self.measure_throughput(self.model_fp32)
        quant_throughput = self.measure_throughput(self.model_quantized)
        print(f" FP32: {fp32_throughput:.2f}")
        print(f" Quantized: {quant_throughput:.2f}")
        print(f" Improvement: {quant_throughput / fp32_throughput:.2f}×")

        print("\n" + "=" * 60)
        return {
            'accuracy': {'fp32': fp32_acc, 'quantized': quant_acc},
            'size': {'fp32': fp32_size, 'quantized': quant_size},
            'latency': {'fp32': fp32_latency, 'quantized': quant_latency},
            'throughput': {'fp32': fp32_throughput, 'quantized': quant_throughput}
        }
# Usage
benchmark = QuantizationBenchmark(model_fp32, model_int8, test_loader)
results = benchmark.run_full_benchmark()
Profiling Quantization Errors
import torch.nn as nn

def analyze_quantization_error(model_fp32, model_quantized, data_loader):
    """Analyze per-layer quantization errors"""
    # Hooks to capture activations
    activations_fp32 = {}
    activations_quant = {}

    def get_activation(name, storage):
        def hook(model, input, output):
            storage[name] = output.detach()
        return hook

    # Register hooks
    for name, module in model_fp32.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            module.register_forward_hook(get_activation(name, activations_fp32))
    for name, module in model_quantized.named_modules():
        if isinstance(module, (nn.quantized.Conv2d, nn.quantized.Linear)):
            module.register_forward_hook(get_activation(name, activations_quant))

    # Run inference
    sample_input, _ = next(iter(data_loader))
    with torch.no_grad():
        _ = model_fp32(sample_input)
        _ = model_quantized(sample_input)

    # Compute errors
    errors = {}
    for name in activations_fp32:
        if name in activations_quant:
            fp32_act = activations_fp32[name]
            quant_act = activations_quant[name].dequantize() if hasattr(
                activations_quant[name], 'dequantize'
            ) else activations_quant[name]
            mse = torch.mean((fp32_act - quant_act) ** 2).item()
            mae = torch.mean(torch.abs(fp32_act - quant_act)).item()
            relative_error = mae / (torch.mean(torch.abs(fp32_act)).item() + 1e-8)
            errors[name] = {
                'mse': mse,
                'mae': mae,
                'relative_error': relative_error
            }

    # Print results
    print("\nPer-Layer Quantization Error Analysis:")
    print(f"{'Layer':<40} {'MSE':<15} {'MAE':<15} {'Relative Error'}")
    print("-" * 80)
    for name, err in sorted(errors.items(), key=lambda x: x[1]['relative_error'], reverse=True):
        print(f"{name:<40} {err['mse']:<15.6f} {err['mae']:<15.6f} {err['relative_error']:.4f}")

    return errors
Common Challenges and Solutions
Challenge 1: Accuracy Degradation
Problem: Quantized model has significantly lower accuracy.
Solutions:
- Use QAT instead of PTQ:
# If PTQ gives poor accuracy, switch to QAT
model.train()
torch.quantization.prepare_qat(model, inplace=True)
# Fine-tune for 3-5 epochs
- Increase calibration data:
# Use more diverse calibration samples
num_calibration_batches = 1000 # Instead of 100
- Mixed precision:
# Keep sensitive layers in higher precision
for name, module in model.named_modules():
    if 'attention' in name or name == 'classifier':
        module.qconfig = fp16_qconfig
- Per-channel quantization:
# Use per-channel for weights
model.qconfig = torch.quantization.QConfig(
activation=default_observer,
weight=per_channel_weight_observer # More accurate
)
Challenge 2: Outliers in Activations
Problem: Few extreme values dominate quantization range.
Solutions:
- Clip outliers:
class ClippedObserver(torch.quantization.MinMaxObserver):
    def __init__(self, percentile=99.9, **kwargs):
        super().__init__(**kwargs)
        self.percentile = percentile

    def forward(self, x_orig):
        x = x_orig.detach()
        min_val = torch.quantile(x, (100 - self.percentile) / 100)
        max_val = torch.quantile(x, self.percentile / 100)
        self.min_val = min_val
        self.max_val = max_val
        return x_orig
- SmoothQuant approach:
# Migrate difficulty from activations to weights
smoothed_weight, smoothed_activation = smooth_quant(
weight, activation, alpha=0.5
)
- Mixed INT8/FP16 (LLM.int8()):
# Process outliers separately in FP16
quantization_config = BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0 # Outlier threshold
)
Challenge 3: Batch Normalization Issues
Problem: Batch norm statistics change after quantization.
Solutions:
- Fuse BN with Conv:
# Always fuse before quantization
torch.quantization.fuse_modules(
model,
[['conv', 'bn', 'relu']],
inplace=True
)
- Recalibrate BN:
def recalibrate_bn(model, data_loader, num_batches=100):
    """Recalculate BN statistics after quantization"""
    model.train()
    with torch.no_grad():
        for i, (inputs, _) in enumerate(data_loader):
            if i >= num_batches:
                break
            model(inputs)
    model.eval()
    return model
Challenge 4: First/Last Layer Sensitivity
Problem: First and last layers are often more sensitive to quantization.
Solution: Keep them in higher precision
def selective_quantization(model):
    """Quantize all layers except first and last"""
    # Set default config
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    # Override first layer
    model.conv1.qconfig = None  # Keep FP32
    # Override last layer
    model.fc.qconfig = None  # Keep FP32
    return model
Challenge 5: Hardware-Specific Issues
Problem: Quantized model doesn't run efficiently on target hardware.
Solutions:
- Use appropriate backend:
# For x86 CPUs
qconfig = torch.quantization.get_default_qconfig('fbgemm')
# For ARM CPUs
qconfig = torch.quantization.get_default_qconfig('qnnpack')
- Ensure operator support:
# Check if operator is supported
from torch.quantization import get_default_qconfig_propagation_list
supported_ops = get_default_qconfig_propagation_list()
- Use framework-specific quantization:
# For mobile deployment
from torch.quantization import quantize_dynamic
from torch.utils.mobile_optimizer import optimize_for_mobile
quantized_model = quantize_dynamic(model)
scripted_model = torch.jit.script(quantized_model)
optimized_model = optimize_for_mobile(scripted_model)
Hardware Considerations
CPU Quantization
x86 CPUs (Intel/AMD):
- Use the fbgemm backend
- INT8 via VNNI (Vector Neural Network Instructions) on modern CPUs
- Best for server deployments
# Configure for x86
torch.backends.quantized.engine = 'fbgemm'
qconfig = torch.quantization.get_default_qconfig('fbgemm')
ARM CPUs:
- Use the qnnpack backend
- Optimized for mobile devices
- Supports NEON instructions
# Configure for ARM
torch.backends.quantized.engine = 'qnnpack'
qconfig = torch.quantization.get_default_qconfig('qnnpack')
GPU Quantization
NVIDIA GPUs:
- Tensor Cores support INT8/INT4
- TensorRT for deployment
- Significant speedup for INT8
# Using TensorRT via torch2trt
from torch2trt import torch2trt
# Create quantized model
x = torch.ones((1, 3, 224, 224)).cuda()
model_trt = torch2trt(
model,
[x],
fp16_mode=False,
int8_mode=True,
int8_calib_dataset=calibration_dataset
)
Mobile/Edge Devices
TensorFlow Lite for mobile:
import tensorflow as tf
# Convert to TFLite with quantization
converter = tf.lite.TFLiteConverter.from_saved_model('model')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Full integer quantization
def representative_dataset():
    for data in calibration_data:
        yield [data]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()
# Save
with open('model_quantized.tflite', 'wb') as f:
    f.write(tflite_model)
ONNX Runtime:
from onnxruntime.quantization import quantize_dynamic, QuantType
model_input = 'model.onnx'
model_output = 'model_quantized.onnx'
quantize_dynamic(
model_input,
model_output,
weight_type=QuantType.QInt8
)
CoreML for iOS:
import coremltools as ct
# Convert PyTorch to CoreML with quantization
traced_model = torch.jit.trace(model, example_input)
coreml_model = ct.convert(
traced_model,
inputs=[ct.TensorType(shape=example_input.shape)],
convert_to="neuralnetwork",
minimum_deployment_target=ct.target.iOS14
)
# Quantize weights to INT8
from coremltools.models.neural_network import quantization_utils
model_int8 = quantization_utils.quantize_weights(coreml_model, nbits=8)
model_int8.save("model_quantized.mlmodel")
Tools and Libraries
PyTorch Quantization
import torch.quantization
# Built-in, well-integrated with PyTorch ecosystem
# Supports dynamic, static, and QAT
TensorFlow/TFLite
import tensorflow as tf
# Excellent mobile support via TFLite
# Supports post-training and QAT
ONNX Runtime
from onnxruntime.quantization import quantize_dynamic
# Framework-agnostic
# Good for cross-platform deployment
bitsandbytes
import bitsandbytes as bnb
# Specialized for LLMs
# Supports 4-bit, 8-bit quantization
# LLM.int8() and NF4
Auto-GPTQ
from auto_gptq import AutoGPTQForCausalLM
# State-of-the-art LLM quantization
# GPTQ algorithm implementation
AutoAWQ
from awq import AutoAWQForCausalLM
# Activation-aware quantization
# Often better than GPTQ for inference
Intel Neural Compressor
from neural_compressor import Quantization
# Comprehensive quantization toolkit
# Supports multiple frameworks
NVIDIA TensorRT
import tensorrt as trt
# High-performance inference
# INT8/FP16 optimization
Best Practices
1. Start with Dynamic Quantization
- Easiest to implement
- No calibration needed
- Good baseline
2. Calibration Data Quality
- Use representative data
- 100-1000 samples usually sufficient
- Diverse coverage of input distribution
3. Layer-wise Sensitivity Analysis
- Identify sensitive layers
- Keep them in higher precision
- Aggressively quantize insensitive layers
4. Fuse Operations
- Always fuse Conv-BN-ReLU
- Reduces quantization error
- Improves performance
5. Measure Everything
- Accuracy
- Latency
- Throughput
- Model size
- Memory usage
6. Target Hardware Matters
- Use appropriate backend (fbgemm/qnnpack)
- Test on actual deployment hardware
- Profile performance
7. Quantization-Aware Architecture
- Avoid operations that don't quantize well
- Prefer quantization-friendly activations such as ReLU6
- Consider quantization during architecture design
8. Version Control Quantized Models
- Track quantization configs
- Document calibration process
- Maintain reproducibility
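The layer-wise sensitivity analysis recommended above can be sketched as a simple sweep: quantize one layer at a time and measure how far the output drifts from the FP32 model. This is a minimal, self-contained sketch on a toy two-layer model with random inputs; the model, data, and mean-absolute-deviation metric are stand-ins, and in practice you would use your real model and validation accuracy.

```python
import copy

import torch
import torch.nn as nn

# Toy stand-ins: a small model and random inputs.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).eval()
x = torch.randn(64, 16)

with torch.no_grad():
    reference = model(x)  # FP32 baseline outputs

sensitivity = {}
for name, module in model.named_modules():
    if not isinstance(module, nn.Linear):
        continue
    # Quantize ONE layer at a time (dynamic INT8) and measure output drift.
    q_model = torch.quantization.quantize_dynamic(
        copy.deepcopy(model),
        {name: torch.quantization.default_dynamic_qconfig},
    )
    with torch.no_grad():
        err = (q_model(x) - reference).abs().mean().item()
    sensitivity[name] = err

# Most sensitive layers first -- candidates to keep in higher precision.
for name, err in sorted(sensitivity.items(), key=lambda kv: -kv[1]):
    print(f"layer {name}: mean abs deviation {err:.6f}")
```

Layers that dominate the deviation are the ones to leave in FP32 (as in the selective quantization example earlier).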
Resources and Papers
Foundational Papers
1. "Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference" (Jacob et al., 2018): introduced integer-arithmetic-only inference with fake-quantization training
2. "A Survey of Quantization Methods for Efficient Neural Network Inference" (Gholami et al., 2021): comprehensive overview of quantization techniques
3. "LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale" (Dettmers et al., 2022): outlier-aware quantization for LLMs
4. "GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers" (Frantar et al., 2023): state-of-the-art PTQ for LLMs
5. "AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration" (Lin et al., 2023): protects salient weights
6. "SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models" (Xiao et al., 2023): smooths activation outliers
7. "QLoRA: Efficient Finetuning of Quantized LLMs" (Dettmers et al., 2023): 4-bit quantization with LoRA fine-tuning
Tutorials and Guides
- PyTorch Quantization Documentation
- TensorFlow Lite Quantization Guide
- Hugging Face Quantization Guide
- NVIDIA TensorRT Documentation
Libraries and Tools
- PyTorch: torch.quantization
- TensorFlow: tf.quantization, TFLite
- ONNX Runtime: onnxruntime.quantization
- bitsandbytes: bitsandbytes
- Auto-GPTQ: auto-gptq
- AutoAWQ: autoawq
- Intel Neural Compressor: neural-compressor
Datasets for Calibration
- ImageNet (computer vision)
- C4, WikiText (language models)
- COCO (object detection)
- Custom domain-specific data (recommended)
Summary
Quantization is an essential technique for deploying neural networks efficiently:
- Reduces model size by 4-8× (INT8, INT4)
- Increases inference speed by 2-4× on appropriate hardware
- Enables edge deployment on resource-constrained devices
- Maintains accuracy with proper techniques (QAT, calibration)
Key Takeaways:
- Choose quantization method based on constraints (time, accuracy, hardware)
- Dynamic quantization: quickest start, good for RNNs/Transformers
- Static quantization: best performance for CNNs
- QAT: highest accuracy for aggressive quantization
- Modern LLMs: GPTQ, AWQ, or bitsandbytes for 4-bit quantization
- Always measure: accuracy, latency, model size, throughput
- Hardware matters: use appropriate backend and test on target device
Quantization transforms impractical models into deployable solutions, making AI accessible on everything from smartphones to data centers.
Interesting Machine Learning Papers
Key papers that shaped the field of machine learning and deep learning.
Table of Contents
- Computer Vision
- Natural Language Processing
- Generative Models
- Reinforcement Learning
- General Machine Learning
- Optimization
Computer Vision
AlexNet (2012)
ImageNet Classification with Deep Convolutional Neural Networks
- Authors: Alex Krizhevsky, Ilya Sutskever, Geoffrey Hinton
- Key Contributions:
- First deep CNN to win ImageNet competition
- Used ReLU activation, dropout, and data augmentation
- GPU training for deep networks
- Reduced error rate from 26% to 15.3%
- Impact: Sparked deep learning revolution
VGGNet (2014)
Very Deep Convolutional Networks for Large-Scale Image Recognition
- Authors: Karen Simonyan, Andrew Zisserman
- Key Contributions:
- Showed depth is crucial (16-19 layers)
- Used small 3x3 filters throughout
- Simple, homogeneous architecture
- Architecture: Stacked 3x3 conv layers, 2x2 max pooling
ResNet (2015)
Deep Residual Learning for Image Recognition
- Authors: Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
- Key Contributions:
- Residual connections solve vanishing gradient problem
- Enabled training of networks with 100+ layers
- Won ImageNet 2015 with 152 layers
- Skip connections: y = F(x) + x
- Impact: Fundamental building block for modern architectures
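The skip connection y = F(x) + x above can be written as a small PyTorch module. This is an illustrative sketch of a basic residual block (channel count and layer choices are assumptions, not the exact ResNet implementation):

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """y = F(x) + x: two conv layers plus an identity skip connection."""
    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + x)  # the skip connection: F(x) + x

block = ResidualBlock(8)
y = block(torch.randn(2, 8, 16, 16))
print(y.shape)  # torch.Size([2, 8, 16, 16])
```

Because the identity path passes gradients through unchanged, stacking many such blocks avoids the vanishing-gradient problem described above.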
Vision Transformer (ViT) (2020)
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
- Authors: Alexey Dosovitskiy et al. (Google Research)
- Key Contributions:
- Applied transformers directly to image patches
- Competitive with CNNs on large datasets
- Self-attention for vision tasks
- Architecture:
- Split image into patches
- Linear embedding of patches
- Add position embeddings
- Standard transformer encoder
YOLO (2015)
You Only Look Once: Unified, Real-Time Object Detection
- Authors: Joseph Redmon et al.
- Key Contributions:
- Single-stage object detection
- Real-time performance (45 FPS)
- End-to-end training
- Grid-based prediction
Mask R-CNN (2017)
Mask R-CNN
- Authors: Kaiming He, Georgia Gkioxari, Piotr Dollár, Ross Girshick
- Key Contributions:
- Instance segmentation framework
- Extends Faster R-CNN with mask branch
- Parallel prediction of masks and classes
Natural Language Processing
Word2Vec (2013)
Efficient Estimation of Word Representations in Vector Space
- Authors: Tomas Mikolov et al. (Google)
- Key Contributions:
- Distributed word representations
- Skip-gram and CBOW architectures
- Captures semantic relationships
- king - man + woman ≈ queen
- Impact: Foundation for modern NLP embeddings
Attention Is All You Need (2017)
Attention Is All You Need
- Authors: Ashish Vaswani et al. (Google Brain)
- Key Contributions:
- Introduced Transformer architecture
- Self-attention mechanism
- No recurrence or convolution
- Parallel training
- Architecture:
- Multi-head self-attention
- Position-wise feed-forward networks
- Positional encoding
- Encoder-decoder structure
- Impact: Revolutionized NLP and beyond
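The core operation is compact enough to sketch directly. Below is a minimal single-head version of the paper's scaled dot-product attention, softmax(QK^T / sqrt(d_k))V; the shapes and random inputs are illustrative:

```python
import math

import torch

def scaled_dot_product_attention(q, k, v):
    """Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V."""
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # (batch, seq, seq)
    weights = scores.softmax(dim=-1)                   # each row sums to 1
    return weights @ v, weights

torch.manual_seed(0)
q = torch.randn(1, 5, 64)  # (batch, seq_len, d_k)
k = torch.randn(1, 5, 64)
v = torch.randn(1, 5, 64)
out, weights = scaled_dot_product_attention(q, k, v)
print(out.shape, weights.shape)  # torch.Size([1, 5, 64]) torch.Size([1, 5, 5])
```

Multi-head attention runs several of these in parallel on learned projections of Q, K, and V and concatenates the results.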
BERT (2018)
BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
- Authors: Jacob Devlin et al. (Google AI)
- Key Contributions:
- Bidirectional pre-training
- Masked Language Modeling (MLM)
- Next Sentence Prediction (NSP)
- Transfer learning for NLP
- Pre-training objectives:
- Mask 15% of tokens, predict them
- Predict if sentence B follows A
- Impact: Set new SOTA on 11 NLP tasks
GPT (2018-2023)
Improving Language Understanding by Generative Pre-Training
- GPT-1 (2018): 117M parameters, unsupervised pre-training
- GPT-2 (2019): 1.5B parameters, zero-shot learning
- GPT-3 (2020): 175B parameters, few-shot learning
- GPT-4 (2023): Multimodal, improved reasoning
Key Contributions:
- Autoregressive language modeling
- Scaling laws for language models
- In-context learning
- Emergent capabilities at scale
T5 (2019)
Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
- Authors: Colin Raffel et al. (Google)
- Key Contributions:
- Unified text-to-text framework
- All NLP tasks as text generation
- Comprehensive study of transfer learning
- Format: "translate English to German: text" → "translation"
ELECTRA (2020)
ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators
- Authors: Kevin Clark et al. (Stanford/Google)
- Key Contributions:
- Replaced token detection (RTD)
- More sample-efficient than BERT
- Generator-discriminator framework
- Discriminator predicts which tokens are replaced
Generative Models
GAN (2014)
Generative Adversarial Networks
- Authors: Ian Goodfellow et al.
- Key Contributions:
- Two-player minimax game
- Generator vs Discriminator
- Implicit density modeling
- Objective: min_G max_D V(D,G) = E[log D(x)] + E[log(1-D(G(z)))]
- Impact: New paradigm for generative modeling
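The minimax objective above translates into two alternating loss terms. This toy sketch (random 2-D "real" data, tiny illustrative networks, and the common non-saturating generator loss rather than the original log(1 − D(G(z))) form) shows one step's loss computation:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy generator (noise -> sample) and discriminator (sample -> real/fake logit)
G = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))
D = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 1))
bce = nn.BCEWithLogitsLoss()

real = torch.randn(32, 2)   # stand-in for real data
z = torch.randn(32, 4)      # latent noise
fake = G(z)

# Discriminator: maximize log D(x) + log(1 - D(G(z)))
d_loss = bce(D(real), torch.ones(32, 1)) + bce(D(fake.detach()), torch.zeros(32, 1))

# Generator (non-saturating variant): maximize log D(G(z))
g_loss = bce(D(fake), torch.ones(32, 1))

print(f"D loss: {d_loss.item():.4f}, G loss: {g_loss.item():.4f}")
```

Training alternates: update D on d_loss, then G on g_loss, each with its own optimizer.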
DCGAN (2015)
Unsupervised Representation Learning with Deep Convolutional GANs
- Authors: Alec Radford, Luke Metz, Soumith Chintala
- Key Contributions:
- Architectural guidelines for stable GAN training
- All convolutional network
- Batch normalization
- No fully connected layers
- Best practices: Strided convolutions, BatchNorm, LeakyReLU
StyleGAN (2018-2020)
A Style-Based Generator Architecture for GANs
- Authors: Tero Karras et al. (NVIDIA)
- Key Contributions:
- Style-based generator
- Adaptive Instance Normalization (AdaIN)
- Progressive growing
- High-quality face generation
- StyleGAN2 improvements: Weight demodulation, path length regularization
VAE (2013)
Auto-Encoding Variational Bayes
- Authors: Diederik Kingma, Max Welling
- Key Contributions:
- Variational inference for latent variable models
- Reparameterization trick
- ELBO objective
- Probabilistic encoder-decoder
- Objective: Maximize ELBO = E[log p(x|z)] - KL(q(z|x)||p(z))
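The reparameterization trick and the ELBO above can be sketched in a few lines; for a Gaussian posterior the KL term has a closed form. The MSE reconstruction term and tensor shapes here are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

def reparameterize(mu, logvar):
    """z = mu + sigma * eps keeps sampling differentiable w.r.t. mu, sigma."""
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + std * eps

def elbo_loss(x, x_recon, mu, logvar):
    """Negative ELBO = reconstruction term + KL(q(z|x) || N(0, I))."""
    recon = F.mse_loss(x_recon, x, reduction="sum")  # stands in for -log p(x|z)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + kl

torch.manual_seed(0)
mu, logvar = torch.zeros(8, 2), torch.zeros(8, 2)
z = reparameterize(mu, logvar)
print(z.shape)  # torch.Size([8, 2])
```

With mu = 0 and logvar = 0 the posterior equals the prior, so the KL term is exactly zero.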
Diffusion Models (2020)
Denoising Diffusion Probabilistic Models
- Authors: Jonathan Ho, Ajay Jain, Pieter Abbeel
- Key Contributions:
- Iterative denoising process
- High-quality image generation
- Stable training
- Process:
- Forward: Gradually add noise
- Reverse: Learn to denoise
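The forward process has a closed form, x_t = sqrt(ᾱ_t)·x_0 + sqrt(1 − ᾱ_t)·ε, which can be sketched directly. The linear β schedule below matches the DDPM paper's setup; the tensor shapes are illustrative:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)        # linear noise schedule
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)    # cumulative product ᾱ_t

def q_sample(x0, t, noise):
    """Sample x_t ~ q(x_t | x_0) in one shot, no iteration needed."""
    a = alpha_bars[t]
    return a.sqrt() * x0 + (1 - a).sqrt() * noise

torch.manual_seed(0)
x0 = torch.randn(4, 3, 8, 8)
noise = torch.randn_like(x0)
xt = q_sample(x0, t=500, noise=noise)
print(xt.shape)  # by t = T-1, ᾱ_t is tiny and x_t is nearly pure noise
```

The reverse model is then trained to predict the noise ε from (x_t, t), which is what enables iterative denoising at sampling time.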
DALL-E 2 (2022)
Hierarchical Text-Conditional Image Generation with CLIP Latents
- Authors: Aditya Ramesh et al. (OpenAI)
- Key Contributions:
- Text-to-image generation
- CLIP guidance
- Prior and decoder models
- Improved image quality and text alignment
Stable Diffusion (2022)
High-Resolution Image Synthesis with Latent Diffusion Models
- Authors: Robin Rombach et al.
- Key Contributions:
- Diffusion in latent space
- More efficient than pixel-space diffusion
- Text-conditional generation
- Open source
Reinforcement Learning
DQN (2013)
Playing Atari with Deep Reinforcement Learning
- Authors: Volodymyr Mnih et al. (DeepMind)
- Key Contributions:
- Deep Q-learning
- Experience replay
- Target network
- End-to-end RL from pixels
- Impact: First deep RL to master Atari games
AlphaGo (2016)
Mastering the game of Go with deep neural networks and tree search
- Authors: David Silver et al. (DeepMind)
- Key Contributions:
- Combined deep learning with Monte Carlo Tree Search
- Policy and value networks
- Self-play training
- Beat world champion Lee Sedol
- AlphaZero (2017): Generalized to chess and shogi
PPO (2017)
Proximal Policy Optimization Algorithms
- Authors: John Schulman et al. (OpenAI)
- Key Contributions:
- Clipped surrogate objective
- Stable policy updates
- Sample efficient
- Easy to implement
- Widely used in practice
MuZero (2019)
Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model
- Authors: Julian Schrittwieser et al. (DeepMind)
- Key Contributions:
- Model-based RL without knowing rules
- Learns dynamics model
- Plans in latent space
- Superhuman performance
Decision Transformer (2021)
Decision Transformer: Reinforcement Learning via Sequence Modeling
- Authors: Lili Chen et al. (Berkeley)
- Key Contributions:
- RL as sequence modeling
- Conditional generation of actions
- Leverages transformer architecture
- Offline RL
General Machine Learning
Dropout (2014)
Dropout: A Simple Way to Prevent Neural Networks from Overfitting
- Authors: Nitish Srivastava et al.
- Key Contributions:
- Randomly drop units during training
- Reduces overfitting
- Ensemble effect
- Simple and effective regularization

Batch Normalization (2015)
Batch Normalization: Accelerating Deep Network Training
- Authors: Sergey Ioffe, Christian Szegedy (Google)
- Key Contributions:
- Normalize layer inputs
- Reduces internal covariate shift
- Enables higher learning rates
- Acts as regularizer
- Operation: Normalize, then scale and shift
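The normalize-then-scale-and-shift operation can be sketched in a few lines. This shows training-time statistics only; the running averages used at inference are omitted, and the input shape is illustrative:

```python
import torch

def batch_norm(x, gamma, beta, eps=1e-5):
    """Normalize each feature over the batch, then scale (gamma) and shift (beta)."""
    mean = x.mean(dim=0)
    var = x.var(dim=0, unbiased=False)
    x_hat = (x - mean) / torch.sqrt(var + eps)
    return gamma * x_hat + beta

torch.manual_seed(0)
x = torch.randn(128, 16) * 5 + 3  # features with non-zero mean and large scale
y = batch_norm(x, gamma=torch.ones(16), beta=torch.zeros(16))
print(y.mean().item(), y.std().item())  # approximately 0 and 1
```

The learnable gamma and beta let the network undo the normalization when that is optimal, so the layer never reduces expressiveness.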
Adam Optimizer (2014)
Adam: A Method for Stochastic Optimization
- Authors: Diederik Kingma, Jimmy Ba
- Key Contributions:
- Adaptive learning rates
- Combines momentum and RMSprop
- Bias correction
- Default optimizer for many tasks
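The update rule, combining momentum, RMSprop-style scaling, and bias correction, fits in a short NumPy sketch; minimizing f(θ) = θ² is a toy problem chosen for illustration:

```python
import numpy as np

def adam_step(theta, grad, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update step."""
    m = b1 * m + (1 - b1) * grad       # first moment (momentum)
    v = b2 * v + (1 - b2) * grad**2    # second moment (RMSprop-style)
    m_hat = m / (1 - b1**t)            # bias correction
    v_hat = v / (1 - b2**t)
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v

# Minimize f(theta) = theta^2 starting from theta = 1.0
theta, m, v = np.array(1.0), 0.0, 0.0
for t in range(1, 2001):
    grad = 2 * theta
    theta, m, v = adam_step(theta, grad, m, v, t)
print(theta)  # close to 0
```

Note the effective step size is roughly lr regardless of the gradient's raw magnitude, which is why Adam needs little learning-rate tuning.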
Layer Normalization (2016)
Layer Normalization
- Authors: Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey Hinton
- Key Contributions:
- Normalize across features
- Better for RNNs and Transformers
- No batch dependence
ELU (2015)
Fast and Accurate Deep Network Learning by Exponential Linear Units
- Authors: Djork-Arné Clevert et al.
- Key Contributions:
- Negative values push mean towards zero
- Reduces bias shift
- Faster learning
Optimization
SGD with Momentum (1999)
On the momentum term in gradient descent learning algorithms
- Key Contributions:
- Accumulate gradients
- Faster convergence
- Reduces oscillations
RMSprop (2012)
Neural Networks for Machine Learning - Lecture 6
- Author: Geoffrey Hinton
- Key Contributions:
- Adaptive learning rates per parameter
- Divides by running average of gradient magnitudes
Learning Rate Schedules
Cosine Annealing (2016)
- SGDR: Stochastic Gradient Descent with Warm Restarts
- Cosine decay with restarts
- Enables finding multiple local minima
One Cycle Policy (2018)
- Super-Convergence: Very Fast Training of Neural Networks
- Cyclical learning rate with momentum
- Train faster with fewer epochs
Interpretability and Explainability
Grad-CAM (2016)
Grad-CAM: Visual Explanations from Deep Networks
- Authors: Ramprasaath Selvaraju et al.
- Key Contributions:
- Visualize what CNN looks at
- Gradient-weighted class activation mapping
- Works with any CNN architecture
LIME (2016)
"Why Should I Trust You?": Explaining Predictions of Any Classifier
- Authors: Marco Tulio Ribeiro et al.
- Key Contributions:
- Local interpretable model-agnostic explanations
- Approximate complex models locally
- Works with any classifier
SHAP (2017)
A Unified Approach to Interpreting Model Predictions
- Authors: Scott Lundberg, Su-In Lee
- Key Contributions:
- Shapley values for feature importance
- Game-theoretic approach
- Consistent and locally accurate
Efficiency and Compression
MobileNets (2017)
MobileNets: Efficient Convolutional Neural Networks for Mobile Vision
- Authors: Andrew Howard et al. (Google)
- Key Contributions:
- Depthwise separable convolutions
- Width and resolution multipliers
- Efficient for mobile devices
SqueezeNet (2016)
SqueezeNet: AlexNet-level accuracy with 50x fewer parameters
- Authors: Forrest Iandola et al.
- Key Contributions:
- Fire modules (squeeze and expand)
- 50x fewer parameters than AlexNet
- Small model size
Knowledge Distillation (2015)
Distilling the Knowledge in a Neural Network
- Authors: Geoffrey Hinton, Oriol Vinyals, Jeff Dean
- Key Contributions:
- Transfer knowledge from large to small model
- Soft targets from teacher
- Temperature scaling
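A common formulation of the distillation objective mixes a temperature-softened KL term against the teacher with ordinary cross-entropy on hard labels. This sketch uses random logits, and the hyperparameters T and alpha are illustrative choices:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.5):
    """Soft targets from the teacher (temperature T) + hard-label cross-entropy."""
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)  # scale so gradients match the hard-label term
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1 - alpha) * hard

torch.manual_seed(0)
student = torch.randn(8, 10, requires_grad=True)
teacher = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))
loss = distillation_loss(student, teacher, labels)
print(loss.item())
```

A higher temperature flattens the teacher's distribution, exposing the "dark knowledge" in the relative probabilities of wrong classes.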
Pruning (2015)
Learning both Weights and Connections for Efficient Neural Networks
- Authors: Song Han et al.
- Key Contributions:
- Remove unimportant weights
- Magnitude-based pruning
- Reduce model size and computation
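Magnitude-based pruning is available directly in PyTorch via torch.nn.utils.prune; here is a minimal sketch on a single layer (the 90% pruning amount is an illustrative choice):

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Magnitude-based pruning: zero out the smallest-magnitude weights.
layer = nn.Linear(64, 64)
prune.l1_unstructured(layer, name="weight", amount=0.9)  # prune 90% of weights

sparsity = (layer.weight == 0).float().mean().item()
print(f"Sparsity: {sparsity:.2%}")  # roughly 90% of weights are zero

# Make pruning permanent (removes the mask, keeps the zeros in `weight`)
prune.remove(layer, "weight")
```

Note that unstructured sparsity only shrinks or speeds up the model on hardware or formats that exploit sparse weights; otherwise it mainly serves as a starting point for structured pruning or sparse storage.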
Meta-Learning
MAML (2017)
Model-Agnostic Meta-Learning for Fast Adaptation
- Authors: Chelsea Finn, Pieter Abbeel, Sergey Levine
- Key Contributions:
- Learn good initialization
- Fast adaptation to new tasks
- Few-shot learning
- Bi-level optimization
Prototypical Networks (2017)
Prototypical Networks for Few-shot Learning
- Authors: Jake Snell, Kevin Swersky, Richard Zemel
- Key Contributions:
- Learn metric space
- Class prototypes as centroids
- Simple and effective
Self-Supervised Learning
SimCLR (2020)
A Simple Framework for Contrastive Learning of Visual Representations
- Authors: Ting Chen et al. (Google)
- Key Contributions:
- Contrastive learning framework
- Large batch sizes crucial
- Strong data augmentation
- No labels needed
BYOL (2020)
Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning
- Authors: Jean-Bastien Grill et al. (DeepMind)
- Key Contributions:
- No negative pairs needed
- Online and target networks
- Momentum encoder
- State-of-the-art representations
MAE (2021)
Masked Autoencoders Are Scalable Vision Learners
- Authors: Kaiming He et al. (Facebook AI)
- Key Contributions:
- Mask random patches
- Reconstruct missing pixels
- Simple and scalable
- Asymmetric encoder-decoder
Papers to Read
Foundational
- Neural Networks and Deep Learning (Nielsen)
- Deep Learning Book (Goodfellow et al.)
- Pattern Recognition and Machine Learning (Bishop)
Recent Surveys
- Attention mechanisms survey
- Transfer learning survey
- Self-supervised learning survey
- Efficient deep learning survey
Follow These Venues
- NeurIPS, ICML, ICLR (ML conferences)
- CVPR, ICCV, ECCV (Computer Vision)
- ACL, EMNLP, NAACL (NLP)
- AAAI, IJCAI (AI)
Resources
- arXiv.org: Pre-prints of latest research
- Papers with Code: Papers with implementations
- Google Scholar: Citation tracking
- Semantic Scholar: AI-powered search
- Distill.pub: Clear explanations
- Two Minute Papers: Video summaries
Artificial Intelligence (AI) Documentation
A comprehensive guide to modern AI technologies, tools, and best practices.
Overview
This directory contains documentation on various AI topics, focusing on practical applications, implementation guides, and best practices for working with modern AI systems.
Contents
1. Prompt Engineering
Learn the art and science of crafting effective prompts for Large Language Models (LLMs):
- Core principles and techniques
- Prompt patterns and templates
- Chain-of-Thought reasoning
- Few-shot and zero-shot learning
- Advanced strategies for different tasks
2. Generative AI
Comprehensive overview of generative AI models and applications:
- Text generation (GPT, Claude, PaLM)
- Image generation (DALL-E, Midjourney, Stable Diffusion)
- Audio and video synthesis
- Multimodal models
- Real-world applications and use cases
3. Stable Diffusion
Detailed guide to Stable Diffusion for image generation:
- Installation and setup
- Prompt engineering for images
- Parameters and settings
- ControlNet and extensions
- Optimization tips
4. Flux.1
Documentation for Black Forest Labs' Flux.1 model:
- Model variants (Dev, Schnell, Pro)
- Setup and usage
- Comparison with other models
- Advanced techniques
5. Llama Models
Complete guide to Meta's Llama family of models:
- Model architecture and variants
- Installation and setup
- Fine-tuning techniques
- Inference optimization
- Deployment strategies
6. Large Language Models (LLMs)
Comprehensive overview of Large Language Models:
- LLM fundamentals and architecture
- Transformer models and attention mechanisms
- Training and inference
- Prompt engineering techniques
- API usage and best practices
7. ComfyUI
Node-based interface for Stable Diffusion workflows:
- Installation and setup
- Workflow creation
- Custom nodes and extensions
- Advanced generation techniques
- Integration with other tools
8. Fine-Tuning
Model adaptation and customization:
- Fine-tuning strategies and approaches
- Parameter-efficient methods (LoRA, QLoRA)
- Dataset preparation and quality
- Training configuration and optimization
- Evaluation and deployment
Key AI Concepts
Large Language Models (LLMs)
LLMs are neural networks trained on vast amounts of text data to understand and generate human-like text. Key characteristics:
- Scale: Billions to trillions of parameters
- Training: Self-supervised learning on diverse text corpora
- Capabilities: Text generation, reasoning, code writing, translation, etc.
- Examples: GPT-4, Claude, Llama, PaLM, Mistral
Transformer Architecture
The foundation of modern LLMs:
Input → Tokenization → Embedding →
Positional Encoding →
Multi-Head Attention →
Feed Forward →
Layer Norm →
Output
Key components:
- Self-Attention: Allows model to weigh importance of different tokens
- Positional Encoding: Provides sequence order information
- Feed-Forward Networks: Process attention outputs
- Residual Connections: Enable training of deep networks
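The attention step above can be sketched in a few lines of plain Python. This is a single-query, single-head toy on tiny hand-picked vectors; real models operate on batched matrices with learned projections:

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of scores."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attention(query, keys, values):
    """Scaled dot-product attention for one query over a tiny sequence."""
    d = len(query)
    # Score each key against the query, scaled by sqrt(d)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    weights = softmax(scores)
    # Output is the attention-weighted sum of the value vectors
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(len(values[0]))]

keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]
print(attention([1.0, 0.0], keys, values))
```

Because the query aligns with the first and third keys, their values dominate the weighted sum.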
Diffusion Models
State-of-the-art image generation approach:
- Forward Process: Gradually add noise to images
- Reverse Process: Learn to denoise, generating new images
- Conditioning: Guide generation with text, images, or other inputs
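One forward (noising) step can be sketched directly, assuming the standard variance-preserving formulation x_t = sqrt(1-β)·x_{t-1} + sqrt(β)·ε:

```python
import math
import random

def forward_diffusion_step(x, beta):
    """One forward step: attenuate the signal and mix in Gaussian noise."""
    return [math.sqrt(1 - beta) * v + math.sqrt(beta) * random.gauss(0, 1) for v in x]

random.seed(0)
x = [1.0] * 8           # a tiny "image" of 8 pixels, all the same value
for t in range(10):     # after enough steps, x approaches pure noise
    x = forward_diffusion_step(x, beta=0.3)
print(x)
```

The trained model learns the reverse of this process: predicting and removing the noise one step at a time.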
Popular AI Tools & Frameworks
For LLMs
# OpenAI API
pip install openai
# Anthropic Claude
pip install anthropic
# Hugging Face Transformers
pip install transformers torch
# LangChain for LLM applications
pip install langchain langchain-community
# LlamaIndex for RAG
pip install llama-index
For Image Generation
# Stable Diffusion WebUI (AUTOMATIC1111)
git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui
cd stable-diffusion-webui
./webui.sh
# ComfyUI (node-based interface)
git clone https://github.com/comfyanonymous/ComfyUI
cd ComfyUI
pip install -r requirements.txt
# Diffusers library
pip install diffusers transformers accelerate
For Model Training & Fine-tuning
# Hugging Face ecosystem
pip install transformers datasets accelerate peft bitsandbytes
# PyTorch
pip install torch torchvision torchaudio
# DeepSpeed for distributed training
pip install deepspeed
# Axolotl for fine-tuning
git clone https://github.com/OpenAccess-AI-Collective/axolotl
Quick Start Examples
Using OpenAI API
from openai import OpenAI
client = OpenAI(api_key="your-api-key")
response = client.chat.completions.create(
model="gpt-4",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Explain quantum computing in simple terms."}
]
)
print(response.choices[0].message.content)
Using Anthropic Claude
import anthropic
client = anthropic.Anthropic(api_key="your-api-key")
message = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[
{"role": "user", "content": "Write a Python function to calculate Fibonacci numbers."}
]
)
print(message.content[0].text)
Using Hugging Face Transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Load model and tokenizer
model_name = "meta-llama/Llama-3.2-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map="auto"
)
# Generate text
prompt = "What is the theory of relativity?"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_length=200)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Using Stable Diffusion
from diffusers import StableDiffusionPipeline
import torch
# Load pipeline
pipe = StableDiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-1",
torch_dtype=torch.float16
)
pipe = pipe.to("cuda")
# Generate image
prompt = "a serene mountain landscape at sunset, oil painting style"
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("output.png")
Best Practices
1. Prompt Engineering
- Be specific and clear in your instructions
- Provide context and examples
- Use system prompts to set behavior
- Iterate and refine based on outputs
2. Model Selection
- Choose the right model for your task
- Balance capability vs. cost vs. speed
- Consider fine-tuning for specialized tasks
- Use quantization for resource constraints
3. Safety & Ethics
- Implement content filtering
- Monitor for bias and fairness
- Respect copyright and attribution
- Ensure data privacy and security
4. Performance Optimization
- Use batch processing when possible
- Implement caching for repeated queries
- Optimize prompts for token efficiency
- Use streaming for real-time responses
Resources
Official Documentation
Learning Resources
Community
Contributing
This documentation is continuously updated with new techniques, models, and best practices. Each section contains practical examples and code snippets that you can use immediately.
License
This documentation is provided for educational purposes. Please refer to individual model and tool licenses for usage terms.
Generative AI
A comprehensive guide to generative AI models, applications, and practical implementations.
Table of Contents
- Introduction
- Core Concepts
- Text Generation
- Image Generation
- Audio Generation
- Video Generation
- Multimodal Models
- Applications
- Implementation Examples
Introduction
Generative AI refers to artificial intelligence systems that can create new content—text, images, audio, video, code, and more. Unlike discriminative models that classify or predict, generative models learn to produce novel outputs that resemble their training data.
Key Characteristics
- Content Creation: Generate new, original content
- Pattern Learning: Understand and replicate complex patterns
- Conditional Generation: Create outputs based on specific inputs/prompts
- Iterative Refinement: Improve outputs through multiple passes
Core Concepts
1. Generative Models
Autoregressive Models
Generate sequences one token at a time, using previous tokens as context:
P(x₁, x₂, ..., xₙ) = P(x₁) × P(x₂|x₁) × P(x₃|x₁,x₂) × ... × P(xₙ|x₁,...,xₙ₋₁)
Examples: GPT series, LLaMA
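The factorization above can be made concrete with a toy bigram model (hypothetical probabilities, not from any real model):

```python
# Toy bigram model: P(next token | current token)
bigram = {
    ("<s>", "the"): 0.5,
    ("the", "cat"): 0.4,
    ("cat", "sat"): 0.6,
}

def sequence_probability(tokens, model):
    """P(x1..xn) as a product of conditionals, per the chain rule."""
    prob = 1.0
    prev = "<s>"
    for tok in tokens:
        prob *= model.get((prev, tok), 0.0)
        prev = tok
    return prob

p = sequence_probability(["the", "cat", "sat"], bigram)
print(p)  # 0.5 * 0.4 * 0.6 = 0.12
```

An LLM does the same thing, except each conditional comes from a neural network over the entire preceding context rather than a lookup table.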
Diffusion Models
Learn to denoise data through iterative refinement:
Forward process: x₀ → x₁ → ... → xₜ (add noise)
Reverse process: xₜ → xₜ₋₁ → ... → x₀ (remove noise)
Examples: Stable Diffusion, DALL-E 3, Midjourney
Variational Autoencoders (VAE)
Learn compressed representations in latent space:
Encoder: x → z (data to latent space)
Decoder: z → x' (latent space to reconstruction)
Generative Adversarial Networks (GAN)
Two networks compete—generator creates, discriminator evaluates:
Generator: z → x (noise to data)
Discriminator: x → [0,1] (real vs fake)
Examples: StyleGAN, BigGAN
2. Foundation Models
Large-scale models trained on vast datasets, adaptable to many tasks:
- Scale: Billions to trillions of parameters
- Transfer Learning: Fine-tune for specific tasks
- Few-Shot Learning: Adapt with minimal examples
- Emergent Abilities: Capabilities not explicitly trained
Text Generation
Large Language Models (LLMs)
GPT Family (OpenAI)
from openai import OpenAI
client = OpenAI(api_key="your-key")
# GPT-4 Turbo - Most capable
response = client.chat.completions.create(
model="gpt-4-turbo-preview",
messages=[
{"role": "system", "content": "You are a creative writer."},
{"role": "user", "content": "Write a short sci-fi story about AI."}
],
temperature=0.8,
max_tokens=500
)
print(response.choices[0].message.content)
Models:
- gpt-4-turbo: Most capable, best for complex tasks
- gpt-4: High capability, slower and more expensive
- gpt-3.5-turbo: Fast, cost-effective for simple tasks
Claude (Anthropic)
import anthropic
client = anthropic.Anthropic(api_key="your-key")
# Claude Sonnet 4.5 - Latest model
message = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[
{
"role": "user",
"content": "Analyze this code and suggest improvements: [code]"
}
]
)
print(message.content[0].text)
Models:
- claude-sonnet-4-5: Balanced performance and capability
- claude-opus-4: Most capable, deep analysis
- claude-haiku-4: Fastest, most cost-effective
Llama (Meta)
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
model_name = "meta-llama/Llama-3.2-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map="auto"
)
# Chat format
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Explain quantum computing."}
]
input_ids = tokenizer.apply_chat_template(
messages,
return_tensors="pt"
).to(model.device)
outputs = model.generate(
input_ids,
max_new_tokens=256,
temperature=0.7,
top_p=0.9
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Mistral
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
client = MistralClient(api_key="your-key")
messages = [
ChatMessage(role="user", content="What is machine learning?")
]
# Mistral Large - Most capable
response = client.chat(
model="mistral-large-latest",
messages=messages
)
print(response.choices[0].message.content)
Use Cases for Text Generation
1. Content Creation
# Blog post generation
prompt = """
Write a 500-word blog post about sustainable living.
Include:
- Engaging introduction
- 3 practical tips
- Statistics or facts
- Call to action
Tone: Informative but conversational
"""
2. Code Generation
# Function generation
prompt = """
Create a Python function that:
- Takes a list of dictionaries
- Filters by a key-value pair
- Sorts by another key
- Returns top N results
Include type hints and docstring.
"""
3. Data Analysis
# Analysis prompt
prompt = """
Analyze this sales data and provide:
1. Key trends
2. Anomalies
3. Predictions
4. Recommendations
Data: [CSV or JSON data]
"""
4. Translation
# Contextual translation
prompt = """
Translate this technical documentation from English to Spanish:
[text]
Maintain:
- Technical terminology accuracy
- Professional tone
- Code examples unchanged
"""
Image Generation
Stable Diffusion
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
import torch
# Load model
model_id = "stabilityai/stable-diffusion-2-1"
pipe = StableDiffusionPipeline.from_pretrained(
model_id,
torch_dtype=torch.float16
)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
pipe.scheduler.config
)
pipe = pipe.to("cuda")
# Generate image
prompt = "a serene japanese garden with cherry blossoms, koi pond, stone lanterns, soft morning light, highly detailed, 4k"
negative_prompt = "blurry, distorted, low quality, watermark"
image = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
num_inference_steps=30,
guidance_scale=7.5,
width=768,
height=768
).images[0]
image.save("japanese_garden.png")
DALL-E 3 (OpenAI)
from openai import OpenAI
client = OpenAI()
response = client.images.generate(
model="dall-e-3",
prompt="A futuristic city with flying cars and neon lights, cyberpunk style, detailed, high quality",
size="1024x1024",
quality="hd",
n=1
)
image_url = response.data[0].url
print(f"Generated image: {image_url}")
Midjourney
Accessed through Discord bot:
/imagine prompt: a mystical forest with glowing mushrooms, ethereal lighting, fantasy art style, intricate details --v 6 --ar 16:9 --q 2
Parameters:
- --v: Version (6 is latest)
- --ar: Aspect ratio
- --q: Quality (0.25, 0.5, 1, 2)
- --s: Stylization (0-1000)
Image-to-Image
from diffusers import StableDiffusionImg2ImgPipeline
from PIL import Image
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-1",
torch_dtype=torch.float16
).to("cuda")
# Load initial image
init_image = Image.open("sketch.png").convert("RGB")
init_image = init_image.resize((768, 768))
# Transform image
prompt = "a professional photograph of a modern building, architectural photography"
images = pipe(
prompt=prompt,
image=init_image,
strength=0.75, # How much to transform (0=no change, 1=complete regeneration)
guidance_scale=7.5,
num_inference_steps=50
).images
images[0].save("transformed.png")
Inpainting
from diffusers import StableDiffusionInpaintPipeline
pipe = StableDiffusionInpaintPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-inpainting",
torch_dtype=torch.float16
).to("cuda")
# Load image and mask
image = Image.open("photo.png")
mask = Image.open("mask.png") # White areas will be regenerated
prompt = "a red sports car"
result = pipe(
prompt=prompt,
image=image,
mask_image=mask,
num_inference_steps=50
).images[0]
result.save("inpainted.png")
Audio Generation
Text-to-Speech
OpenAI TTS
from openai import OpenAI
from pathlib import Path
client = OpenAI()
speech_file_path = Path("output.mp3")
response = client.audio.speech.create(
model="tts-1-hd",
voice="nova", # alloy, echo, fable, onyx, nova, shimmer
input="Hello! This is a generated voice. AI can now speak naturally."
)
response.stream_to_file(speech_file_path)
ElevenLabs
from elevenlabs import generate, play, set_api_key
set_api_key("your-api-key")
audio = generate(
text="Welcome to the future of voice synthesis.",
voice="Bella",
model="eleven_monolingual_v1"
)
play(audio)
Music Generation
MusicGen (Meta)
from audiocraft.models import MusicGen
import torchaudio
model = MusicGen.get_pretrained('facebook/musicgen-medium')
# Generate music
descriptions = ['upbeat electronic dance music with strong bass']
duration = 30 # seconds
model.set_generation_params(duration=duration)
wav = model.generate(descriptions)
# Save
for idx, one_wav in enumerate(wav):
torchaudio.save(f'generated_{idx}.wav', one_wav.cpu(), model.sample_rate)
Video Generation
Stable Video Diffusion
from diffusers import StableVideoDiffusionPipeline
from PIL import Image
pipe = StableVideoDiffusionPipeline.from_pretrained(
"stabilityai/stable-video-diffusion-img2vid-xt",
torch_dtype=torch.float16,
variant="fp16"
)
pipe.to("cuda")
# Load initial image
image = Image.open("first_frame.png")
# Generate video frames
frames = pipe(image, decode_chunk_size=8, num_frames=25).frames[0]
# Save as video
from diffusers.utils import export_to_video
export_to_video(frames, "output_video.mp4", fps=7)
RunwayML Gen-2
API-based video generation (illustrative sketch; check the current RunwayML SDK documentation for exact method names):
import runwayml
client = runwayml.RunwayML(api_key="your-key")
# Text to video
task = client.image_generation.create(
prompt="a serene ocean at sunset with waves gently crashing",
model="gen2",
duration=4
)
# Wait for completion and download
video_url = task.get_output_url()
Multimodal Models
GPT-4 Vision
from openai import OpenAI
client = OpenAI()
response = client.chat.completions.create(
model="gpt-4-vision-preview",
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {
"url": "https://example.com/image.jpg"
}
}
]
}
],
max_tokens=300
)
print(response.choices[0].message.content)
Claude Vision
import anthropic
import base64
client = anthropic.Anthropic()
# Read and encode image
with open("image.jpg", "rb") as image_file:
image_data = base64.standard_b64encode(image_file.read()).decode("utf-8")
message = client.messages.create(
model="claude-sonnet-4-5-20250929",
max_tokens=1024,
messages=[
{
"role": "user",
"content": [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": image_data,
},
},
{
"type": "text",
"text": "Describe this image in detail."
}
],
}
],
)
print(message.content[0].text)
LLaVA (Open Source)
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images
from PIL import Image
model_path = "liuhaotian/llava-v1.5-7b"
tokenizer, model, image_processor, context_len = load_pretrained_model(
model_path=model_path,
model_base=None,
model_name=get_model_name_from_path(model_path)
)
# Load and process image
image = Image.open("photo.jpg")
image_tensor = process_images([image], image_processor, model.config)
# Generate description
prompt = "Describe this image in detail."
outputs = model.generate(
image_tensor,
prompt,
max_new_tokens=512
)
Applications
1. Content Creation
# Automated blog writing pipeline
def generate_blog_post(topic):
# Research
outline_prompt = f"Create a detailed outline for a blog post about {topic}"
outline = llm.generate(outline_prompt)
# Write sections
sections = []
for section in outline.sections:
content = llm.generate(f"Write about: {section}")
sections.append(content)
# Generate image
image_prompt = f"blog header image for {topic}, professional, modern"
image = image_generator.generate(image_prompt)
return {
'outline': outline,
'content': sections,
'image': image
}
2. Education & Training
# Personalized tutoring
def create_lesson(topic, student_level, learning_style):
prompt = f"""
Create a {student_level}-level lesson on {topic} for a {learning_style} learner.
Include:
- Clear explanations with analogies
- 3 practice problems
- Visual aids descriptions
"""
lesson = llm.generate(prompt)
# Generate visual aids
visuals = [
image_gen.generate(desc)
for desc in lesson.visual_descriptions
]
return lesson, visuals
3. Software Development
# AI-assisted coding
def code_assistant(task_description, language="python"):
# Generate code
code_prompt = f"Write {language} code for: {task_description}"
code = llm.generate(code_prompt)
# Generate tests
test_prompt = f"Write unit tests for this code:\n{code}"
tests = llm.generate(test_prompt)
# Generate documentation
doc_prompt = f"Write comprehensive documentation for:\n{code}"
docs = llm.generate(doc_prompt)
return {
'code': code,
'tests': tests,
'docs': docs
}
4. Marketing & Advertising
# Campaign generation
def create_marketing_campaign(product, target_audience):
# Generate copy variations
copy_prompt = f"""
Create 5 ad copy variations for {product} targeting {target_audience}.
Each should be:
- Under 100 characters
- Compelling call-to-action
- Different emotional angle
"""
copies = llm.generate(copy_prompt)
    # Generate visuals and assemble the campaign
    campaign = []
    for copy in copies:
        visual_prompt = f"advertising image for: {copy}, {product}, professional photography"
        image = image_gen.generate(visual_prompt)
        campaign.append({'copy': copy, 'image': image})
    return campaign
5. Data Augmentation
# Expand training dataset
def augment_dataset(original_data):
augmented = []
for item in original_data:
# Text augmentation
variations = llm.generate(
f"Create 5 paraphrases of: {item.text}"
)
augmented.extend(variations)
# Image augmentation (if applicable)
if item.image:
synthetic_images = image_gen.generate(
f"similar to: {item.image_description}"
)
augmented.extend(synthetic_images)
return augmented
6. Accessibility
# Multi-modal accessibility
def make_accessible(content):
    text, audio, image = None, None, None
    if content.is_text():
        text = content.text
        # Text to speech
        audio = tts.generate(content.text)
        # Generate descriptive images
        image = image_gen.generate(f"illustration of: {content.text}")
    elif content.is_image():
        image = content.image
        # Image to text description
        text = vision_model.describe(content.image)
        # Text to speech
        audio = tts.generate(text)
    return {
        'text': text,
        'audio': audio,
        'image': image
    }
Best Practices
1. Prompt Engineering
# Good prompt structure
prompt = """
Role: You are an expert {domain} specialist
Task: {specific_task}
Context: {relevant_background}
Requirements:
- {requirement_1}
- {requirement_2}
- {requirement_3}
Format: {output_format}
"""
2. Temperature & Sampling
# Creative tasks: High temperature
creative_config = {
"temperature": 0.8,
"top_p": 0.9,
"top_k": 50
}
# Factual tasks: Low temperature
factual_config = {
"temperature": 0.2,
"top_p": 0.95,
"top_k": 40
}
3. Error Handling
def generate_with_retry(prompt, max_retries=3, fallback_response=None):
    for attempt in range(max_retries):
        try:
            response = llm.generate(prompt)
            # Validate before returning; retry on invalid output
            if validate(response):
                return response
        except Exception:
            if attempt == max_retries - 1:
                raise
    return fallback_response
4. Cost Optimization
# Cache responses
from functools import lru_cache
@lru_cache(maxsize=1000)
def generate_cached(prompt):
return llm.generate(prompt)
# Batch requests
def generate_batch(prompts):
return llm.batch_generate(prompts)
# Use appropriate model
def select_model(task_complexity):
if task_complexity == "simple":
return "gpt-3.5-turbo" # Cheaper
else:
return "gpt-4" # More capable
Ethical Considerations
1. Content Authenticity
# Add watermarks to generated content
def generate_with_watermark(prompt):
content = llm.generate(prompt)
metadata = {
'generated_by': 'AI',
'model': 'gpt-4',
'timestamp': datetime.now(),
'watermark': True
}
return content, metadata
2. Bias Detection
# Check for biased outputs
def check_bias(generated_content):
bias_check_prompt = f"""
Analyze this content for potential bias:
{generated_content}
Check for:
- Gender bias
- Racial bias
- Cultural bias
- Age bias
"""
analysis = llm.generate(bias_check_prompt)
return analysis
3. Safety Filters
# Content filtering
def safe_generate(prompt):
# Check input
if contains_unsafe_content(prompt):
return "Request rejected: unsafe content"
# Generate
output = llm.generate(prompt)
# Check output
if contains_unsafe_content(output):
return "Generation failed: unsafe output"
return output
Future Trends
1. Multimodal Foundation Models
- Unified models handling text, image, audio, video
- Seamless cross-modal generation
2. Personalization
- Models adapting to individual user preferences
- Context-aware generation
3. Efficiency
- Smaller, faster models with comparable quality
- Edge deployment of generative models
4. Controllability
- Fine-grained control over generation
- Steering models toward specific outputs
5. Collaboration
- Human-AI co-creation workflows
- Interactive refinement systems
Resources
Learning
Tools
Communities
- r/StableDiffusion
- r/LocalLLaMA
- Discord: Stable Diffusion, Midjourney
- Twitter/X: AI researchers and practitioners
Conclusion
Generative AI is rapidly evolving, with new models and capabilities emerging constantly. Success comes from understanding the fundamentals, choosing appropriate tools, and applying ethical practices. Experiment, iterate, and stay updated with the latest developments.
Large Language Models (LLMs)
Overview
Large Language Models are transformer-based neural networks trained on massive text corpora to predict and generate human language. They've revolutionized AI with capabilities in translation, summarization, question-answering, and reasoning.
Architecture Basics
Transformers
Built on self-attention mechanism:
- Query-Key-Value: "What am I looking for?" -> "Where's the relevant info?" -> "Get the info"
- Multi-head Attention: Multiple attention patterns in parallel
- Feed-forward Networks: Non-linear transformations
- Layer Normalization: Stabilizes training
Scaling Laws
Performance improves predictably with:
- Model size (parameters): 7B -> 70B -> 700B
- Dataset size: More tokens = better performance
- Compute: More training = better convergence
Popular Models
| Model | Size | Training Data | Strengths |
|---|---|---|---|
| GPT-4 | ~1.7T params | ~13T tokens | Reasoning, coding, creative |
| Claude | ~100B params | High quality data | Instruction following, safety |
| Llama 2 | 7B-70B | 2T tokens | Open-source, efficient |
| Mistral | 7B-8x7B | 7T tokens | Fast, efficient |
| PaLM 2 | ~340B | High quality | Reasoning, math |
Training Process
1. Pre-training
Objective: Predict next token
Input: "The cat sat on the"
Target: "mat"
Loss = -log P(mat | previous tokens)
Train on unlabeled internet text (unsupervised)
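The loss above in code, using a toy distribution (a real model computes this over a vocabulary of tens of thousands of tokens):

```python
import math

def next_token_loss(probs, target):
    """Cross-entropy for one prediction: -log P(target | context)."""
    return -math.log(probs[target])

# Model's predicted distribution for "The cat sat on the ___"
probs = {"mat": 0.7, "hat": 0.2, "dog": 0.1}
loss = next_token_loss(probs, "mat")
print(loss)  # -log(0.7), about 0.357
```

Training pushes probability mass toward the observed next token, which drives this loss toward zero.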
2. Supervised Fine-tuning
Input: "What is 2+2?"
Target: "2+2=4"
Train on labeled examples (supervised)
3. RLHF (Reinforcement Learning from Human Feedback)
1. Generate multiple responses
2. Humans rank by quality
3. Train reward model
4. Use reward to optimize policy
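Step 3 (training the reward model) is typically done with a pairwise ranking loss over human preferences; a minimal sketch of the Bradley-Terry style objective:

```python
import math

def pairwise_ranking_loss(reward_chosen, reward_rejected):
    """-log sigmoid(r_chosen - r_rejected): low when the preferred
    response gets the higher reward, high when the ranking is wrong."""
    margin = reward_chosen - reward_rejected
    return -math.log(1 / (1 + math.exp(-margin)))

print(pairwise_ranking_loss(2.0, 0.5))  # small loss: ranking is correct
print(pairwise_ranking_loss(0.5, 2.0))  # large loss: ranking is wrong
```

The policy is then optimized (e.g. with PPO) to maximize this learned reward.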
Key Concepts
Tokenization
Convert text to numbers:
"Hello world" -> [15339, 1159]
Embeddings
Represent tokens as vectors in semantic space:
king - man + woman ~= queen
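The classic analogy can be reproduced with hypothetical 3-dimensional embeddings (real embeddings have hundreds of dimensions learned from data):

```python
import math

# Hypothetical toy embeddings, chosen so the analogy works out
emb = {
    "king":  [0.9, 0.8, 0.1],
    "man":   [0.5, 0.9, 0.0],
    "woman": [0.5, 0.1, 0.9],
    "queen": [0.9, 0.0, 1.0],
}

def cosine(a, b):
    """Cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)

# king - man + woman
result = [k - m + w for k, m, w in zip(emb["king"], emb["man"], emb["woman"])]
# Nearest word in the vocabulary by cosine similarity
nearest = max(emb, key=lambda w: cosine(emb[w], result))
print(nearest)  # queen
```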
Context Window
Maximum tokens model can consider:
- GPT-3: 2K tokens
- GPT-4: 32K - 128K tokens
- Claude: 100K+ tokens
- Llama 2: 4K tokens
Temperature
Controls randomness of output:
- 0: Deterministic (always same answer)
- 0.7: Balanced (varied but coherent)
- 1+: Creative (more random)
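Mechanically, temperature divides the logits before the softmax; as T approaches 0 the distribution collapses onto the highest-scoring token. A minimal sketch:

```python
import math

def apply_temperature(logits, temperature):
    """Softmax with temperature: lower T sharpens, higher T flattens."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.5]
print(apply_temperature(logits, 0.2))  # nearly all mass on the top token
print(apply_temperature(logits, 1.5))  # flatter: sampling is more varied
```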
Prompting Techniques
1. Zero-Shot
Question: What is 2+2?
Answer: 4
2. Few-Shot
Question: What is 3+3?
Answer: 6
Question: What is 2+2?
Answer:
3. Chain-of-Thought
Q: If there are 3 apples and you add 2 more, how many are there?
A: Let me think step by step:
1. Start with 3 apples
2. Add 2 more
3. Total: 3 + 2 = 5
4. Role-Based
You are a helpful Python expert.
Q: How do I reverse a list?
A: [explanations as Python expert]
Limitations
Hallucinations
Making up false information confidently:
Q: What's the capital of Atlantis?
A: The capital is Poseidiopolis. (Made up!)
Knowledge Cutoff
No information beyond training data:
Q: Who won the 2025 World Cup?
A: I don't have info beyond April 2024.
Context Length
Can't process extremely long documents
Reasoning
Struggles with:
- Multi-step complex logic
- Mathematics (prone to errors)
- Counting tokens accurately
Fine-tuning Approaches
Full Fine-tuning
Update all parameters (expensive):
Memory: O(parameters)
Time: O(tokens)
LoRA (Low-Rank Adaptation)
Add small trainable matrices (efficient):
# Instead of: W' = W + delta_W
# Use: W' = W + A*B (where A, B << W)
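The savings are easy to quantify: for one d x d weight matrix, LoRA trains d*r + r*d parameters instead of d*d (illustrative sizes below):

```python
def lora_params(d_in, d_out, rank):
    """Trainable parameters: A is (d_in x rank), B is (rank x d_out)."""
    return d_in * rank + rank * d_out

d = 4096                          # a typical hidden size for a 7B model
full = d * d                      # full fine-tuning of one weight matrix
lora = lora_params(d, d, rank=8)  # low-rank update with r = 8
print(full, lora, full // lora)   # LoRA trains 256x fewer weights here
```

Multiplied across every attention and feed-forward matrix, this is why LoRA fits on a single consumer GPU.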
QLoRA
Quantized LoRA (even more efficient):
- 4-bit quantization
- Reduces memory to ~6GB for 7B model
Applications
| Use Case | Technique | Example |
|---|---|---|
| Chat | Conversation history | ChatGPT |
| Code | In-context learning | GitHub Copilot |
| Search | Semantic ranking | Perplexity AI |
| Translation | Multilingual models | Google Translate |
| Summarization | Extractive/abstractive | Claude summarization |
Costs & Efficiency
API Usage
Pricing: $ per 1M tokens
GPT-4: $30 input, $60 output
Claude: $8 input, $24 output
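A quick cost estimate using the example rates above (prices change frequently; always check the provider's current pricing page):

```python
def estimate_cost(input_tokens, output_tokens, in_rate, out_rate):
    """Cost in dollars, given per-1M-token rates."""
    return (input_tokens / 1e6) * in_rate + (output_tokens / 1e6) * out_rate

# GPT-4 at the example rates above: $30 / 1M input, $60 / 1M output
cost = estimate_cost(input_tokens=50_000, output_tokens=10_000,
                     in_rate=30, out_rate=60)
print(f"${cost:.2f}")  # $2.10
```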
Running Locally
| Model Size | VRAM Needed | Speed |
|---|---|---|
| 7B params | 16GB | Fast |
| 13B params | 24GB | Medium |
| 70B params | 80GB (GPU) | Slow |
Evaluation Metrics
| Metric | What It Measures |
|---|---|
| Perplexity | How well model predicts text |
| BLEU | Translation quality |
| ROUGE | Summarization quality |
| Human Eval | Actual user satisfaction |
Best Practices
1. Prompt Engineering
❌ Bad: "Write code"
✅ Good: "Write a Python function that takes a list and returns it sorted in ascending order"
2. Breaking Complex Tasks
Instead of: "Analyze this company and give investment advice"
Try:
1. "Summarize this company's financials"
2. "What are the main risks?"
3. "What are growth opportunities?"
4. "Should we invest?"
3. Verification
Always verify facts from authoritative sources
ELI10
Imagine teaching a child language by:
- Reading millions of books
- Learning to predict next word
- Getting feedback on quality
- Adjusting understanding
That's basically how LLMs learn! They become really good at continuing conversations in natural human language.
The trick: They learn statistics of language, not true understanding. So they might confidently say wrong things (hallucinations).
Future Directions
- Multimodal: Understanding images + text + audio
- Long Context: Processing entire books
- Reasoning: Better at logic puzzles
- Efficiency: Running on phones/devices
- Robotics: Language guiding physical actions
Further Resources
Prompt Engineering
A comprehensive guide to crafting effective prompts for Large Language Models (LLMs).
Table of Contents
- Introduction
- Core Principles
- Fundamental Techniques
- Advanced Techniques
- Prompt Patterns
- Best Practices
- Common Pitfalls
- Examples by Task
Introduction
Prompt engineering is the practice of designing inputs to get desired outputs from LLMs. It's both an art and a science, requiring understanding of:
- How models process and interpret text
- What patterns yield consistent results
- How to balance specificity with flexibility
Core Principles
1. Clarity and Specificity
Be explicit about what you want:
❌ Bad: "Write about dogs"
✅ Good: "Write a 300-word informative article about the benefits of adopting rescue dogs, including health, cost, and emotional aspects."
2. Context Provision
Give the model necessary background:
❌ Bad: "What should I do?"
✅ Good: "I'm a Python developer with 3 years of experience. I want to transition into machine learning. What skills should I prioritize learning first?"
3. Format Specification
Define the desired output structure:
❌ Bad: "Tell me about the solar system"
✅ Good: "List the planets in our solar system in a markdown table with columns: Name, Distance from Sun (AU), and One Interesting Fact."
4. Role Assignment
Set the model's perspective:
"You are an experienced DevOps engineer. Explain Kubernetes deployments to a junior developer who has only worked with traditional hosting."
Fundamental Techniques
Zero-Shot Prompting
Direct instruction without examples:
Prompt: "Classify the sentiment of this review: 'The product arrived damaged but customer service was helpful.' Choose: positive, negative, or mixed."
Output: "mixed"
Few-Shot Prompting
Provide examples to guide the model:
Classify movie reviews as positive or negative:
Review: "A masterpiece of cinema!"
Sentiment: positive
Review: "Boring and predictable plot."
Sentiment: negative
Review: "Waste of time and money."
Sentiment: negative
Review: "Incredible performances by the cast."
Sentiment: positive
Review: "The special effects were amazing but the story was weak."
Sentiment:
Chain-of-Thought (CoT)
Encourage step-by-step reasoning:
Prompt: "A cafeteria has 23 apples. If they used 20 to make lunch and bought 6 more, how many apples do they have? Let's think step by step."
Output:
"Let's solve this step by step:
1. Starting apples: 23
2. Used for lunch: 23 - 20 = 3 apples remaining
3. Bought more: 3 + 6 = 9 apples
Answer: The cafeteria has 9 apples."
Zero-Shot Chain-of-Thought
Add "Let's think step by step" to enable reasoning:
Prompt: "If a train travels 120 miles in 2 hours, then speeds up and travels 180 miles in the next 2 hours, what's the average speed for the entire journey? Let's think step by step."
Self-Consistency
Generate multiple reasoning paths and choose the most consistent:
# Ask the same question with slight variations
prompts = [
"Calculate 15% tip on $47.50. Show your work.",
"What's a 15% tip on a $47.50 bill? Explain your calculation.",
"If my bill is $47.50 and I want to leave 15%, how much is the tip?"
]
# The most common answer is likely correct
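The voting step itself can be as simple as a majority count over the final answers extracted from each run (the answers below are hypothetical outputs):

```python
from collections import Counter

def self_consistency(answers):
    """Pick the answer that appears most often across reasoning paths."""
    answer, count = Counter(answers).most_common(1)[0]
    return answer

# Final answers extracted from three separate model runs
runs = ["$7.13", "$7.13", "$7.12"]
print(self_consistency(runs))  # $7.13
```

With more sampled paths, occasional arithmetic slips get outvoted by the correct majority.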
Advanced Techniques
Tree of Thoughts (ToT)
Explore multiple reasoning branches:
Problem: Design a marketing campaign for a new eco-friendly water bottle.
Let's explore three different approaches:
Approach 1: Sustainability Focus
- Highlight environmental impact
- Partner with conservation organizations
- Target eco-conscious millennials
[Evaluate pros/cons]
Approach 2: Innovation Focus
- Emphasize unique design features
- Tech-forward marketing
- Target early adopters
[Evaluate pros/cons]
Approach 3: Health & Wellness Focus
- Connect to healthy lifestyle
- Partner with fitness influencers
- Target health-conscious consumers
[Evaluate pros/cons]
Now, let's combine the best elements...
ReAct (Reasoning + Acting)
Interleave reasoning with actions:
Task: Find information about the latest Python version
Thought: I need to find current Python version information
Action: Search for "latest Python version 2025"
Observation: Python 3.13 was released in October 2024
Thought: I should verify this is the most recent stable version
Action: Check Python.org official releases
Observation: Confirmed, Python 3.13 is the latest stable version
Answer: The latest Python version is 3.13, released in October 2024
Prompt Chaining
Break complex tasks into steps:
# Step 1: Research
prompt1 = "List 5 key features of electric vehicles vs gasoline cars"
# Step 2: Analyze (using output from step 1)
prompt2 = f"Given these EV features: {output1}, which three are most important for urban commuters?"
# Step 3: Synthesize
prompt3 = f"Based on these priority features: {output2}, write a 100-word recommendation"
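The steps above can be wired together in a small driver function. This is a sketch: `generate` is a stand-in for whatever model call you use (an API client, a local pipeline, etc.), and the stub below only echoes prompts for demonstration.

```python
def run_chain(generate, topic):
    """Run a three-step prompt chain, feeding each output into the next prompt."""
    features = generate(f"List 5 key features of {topic} vs alternatives")
    priorities = generate(
        f"Given these features: {features}, "
        "which three are most important for urban commuters?"
    )
    return generate(
        f"Based on these priority features: {priorities}, "
        "write a 100-word recommendation"
    )

# Stub model call for demonstration; swap in a real API client
echo = lambda prompt: f"[answer to: {prompt[:40]}...]"
print(run_chain(echo, "electric vehicles"))
```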
Automatic Prompt Engineering (APE)
Let the model optimize its own prompts:
Meta-prompt: "I want to classify customer support tickets into categories: billing, technical, general inquiry. Generate 5 different prompts that would work well for this classification task."
Prompt Patterns
The Persona Pattern
"Act as [role] with [characteristics]. Your task is to [objective]."
Example:
"Act as a senior software architect with 15 years of experience in microservices. Review this code design and suggest improvements for scalability."
The Template Pattern
"[Action] about [topic] in [format] with [constraints]."
Example:
"Write about artificial intelligence in a blog post format with a friendly tone, 500 words max, aimed at non-technical readers."
The Constraint Pattern
"[Task]. You must [requirement 1]. You must [requirement 2]. You cannot [restriction]."
Example:
"Write a product description. You must include benefits, not just features. You must use active voice. You cannot use technical jargon."
The Refinement Pattern
Initial prompt → Generate → Critique → Revise
Example:
"Write a haiku about coding."
[output]
"Now critique this haiku for syllable count and imagery."
[critique]
"Revise the haiku based on the critique."
The Comparative Pattern
"Compare [A] and [B] in terms of [criterion 1], [criterion 2], and [criterion 3]. Present as [format]."
Example:
"Compare REST API and GraphQL in terms of performance, flexibility, and ease of use. Present as a comparison table."
The Instruction-Context-Format (ICF) Pattern
# Instruction
[What to do]
# Context
[Background information]
# Format
[How to structure the output]
Example:
# Instruction
Explain how photosynthesis works
# Context
The audience is 5th-grade students learning about plant biology for the first time
# Format
Use an analogy with a familiar concept, then provide 3-5 simple bullet points
Best Practices
1. Use Delimiters
Clearly separate different parts of your prompt:
Summarize the text delimited by triple quotes.
Text: """
[long text here]
"""
Requirements:
- 3 sentences maximum
- Highlight main argument
- Use neutral tone
2. Specify Output Format
"Provide your answer as a JSON object with the following structure:
{
"summary": "brief overview",
"key_points": ["point1", "point2", "point3"],
"recommendation": "actionable advice"
}"
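When you request structured output, validate the reply before using it, since models occasionally return malformed JSON. A minimal check (the required keys here mirror the example structure above):

```python
import json

REQUIRED_KEYS = {"summary", "key_points", "recommendation"}

def parse_structured_reply(reply):
    """Parse a model reply expected to be a JSON object; fail loudly otherwise."""
    data = json.loads(reply)  # raises json.JSONDecodeError on malformed output
    missing = REQUIRED_KEYS - data.keys()
    if missing:
        raise ValueError(f"missing keys: {sorted(missing)}")
    return data

reply = '{"summary": "s", "key_points": ["p1"], "recommendation": "r"}'
print(parse_structured_reply(reply)["summary"])  # -> s
```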
3. Request Step-by-Step Thinking
"Before answering, explain your reasoning process. Then provide the final answer clearly labeled."
4. Use Examples Strategically
# For few-shot learning, provide diverse examples:
Input: "The cat sat on the mat" → Simple sentence
Input: "Although tired, she completed the marathon" → Complex sentence
Input: "Run!" → Imperative sentence
Input: "Is it raining?" → Interrogative sentence
Input: "What a beautiful day!" →
5. Iterate and Refine
# Version 1: Too vague
"Write code for a web scraper"
# Version 2: More specific
"Write Python code for a web scraper using BeautifulSoup"
# Version 3: Complete specification
"Write Python code using BeautifulSoup to scrape product names and prices from an e-commerce site. Include error handling for missing elements and rate limiting to respect the server."
6. Control Length
"Explain quantum entanglement in [50/100/200] words"
"Provide a [brief/moderate/detailed] explanation"
"Summarize in [2-3 sentences/one paragraph/300 words]"
7. Set the Temperature
Understand model parameters:
# Creative tasks (high temperature: 0.7-1.0)
{"temperature": 0.9}
# "Write a creative story about a time-traveling cat"
# Factual tasks (low temperature: 0.0-0.3)
{"temperature": 0.1}
# "What is the capital of France?"
# Balanced tasks (medium temperature: 0.4-0.6)
{"temperature": 0.5}
# "Explain the pros and cons of remote work"
Common Pitfalls
1. Ambiguity
❌ "Tell me about Python"
✅ "Explain Python's list comprehension syntax with 3 examples"
2. Conflicting Instructions
❌ "Write a detailed brief summary"
✅ "Write a summary in 2-3 sentences covering the main points"
3. Assuming Knowledge
❌ "Debug this code" [without context]
✅ "This Python function should sort a list but returns an error. Debug it: [code]. The error message is: [error]"
4. Overcomplicating
❌ [500-word prompt with 20 constraints]
✅ [Clear, focused prompt with 3-5 key requirements]
5. Not Testing Variations
Always try multiple phrasings:
- "List the benefits"
- "What are the advantages"
- "Explain why this is useful"
Examples by Task
Code Generation
Task: Create a Python function
Prompt:
"Write a Python function named 'calculate_statistics' that:
- Takes a list of numbers as input
- Returns a dictionary with: mean, median, mode, and standard deviation
- Handles edge cases (empty list, single value)
- Includes docstring with examples
- Uses only standard library modules"
Data Analysis
Task: Analyze sales data
Prompt:
"Given this sales data in CSV format:
[data]
Perform the following analysis:
1. Calculate total revenue by product category
2. Identify the top 3 performing products
3. Calculate month-over-month growth rate
4. Provide 3 actionable insights
Present findings in a structured format with clear headers."
Content Writing
Task: Write a blog post
Prompt:
"Write a 600-word blog post about 'The Future of Remote Work'
Structure:
- Engaging headline
- Hook in first paragraph
- 3 main sections with subheadings
- Include statistics or examples
- Conclude with actionable takeaway
Tone: Professional yet conversational
Audience: Mid-level professionals and managers
SEO keywords: remote work, hybrid model, workplace flexibility"
Summarization
Task: Summarize a technical document
Prompt:
"Summarize the following technical documentation:
[document]
Create two versions:
1. Executive Summary (100 words): High-level overview for non-technical stakeholders
2. Technical Summary (300 words): Key technical details for engineering team
Highlight any critical warnings or breaking changes."
Translation with Context
Task: Contextual translation
Prompt:
"Translate the following English text to Spanish:
'The system is down'
Context: This is an IT status message displayed to users during an outage.
Requirements:
- Use appropriate technical terminology
- Maintain professional tone
- Ensure clarity for non-technical users"
Code Review
Task: Review code quality
Prompt:
"Review this Python code for:
[code]
Evaluate:
1. Code quality and readability
2. Performance considerations
3. Potential bugs or edge cases
4. Security issues
5. Best practices adherence
Provide specific suggestions with code examples where applicable.
Rate each category from 1-5 and explain your ratings."
Question Answering
Task: Answer with citations
Prompt:
"Answer the following question using only information from the provided text. Quote relevant passages to support your answer.
Text: [document]
Question: [question]
Format:
- Direct answer (1-2 sentences)
- Supporting evidence (2-3 quoted passages)
- Confidence level (high/medium/low)"
Creative Writing
Task: Story generation
Prompt:
"Write a short story (500 words) with these elements:
Setting: Cyberpunk city in 2150
Protagonist: AI rights activist
Conflict: Choice between following the law or doing what's right
Theme: Question of consciousness and personhood
Tone: Noir detective style
Include:
- Vivid sensory details
- Internal monologue
- Unexpected twist ending"
Advanced Prompt Engineering
Meta-Prompting
"I need to create prompts for classifying customer emails. First, analyze what makes a good classification prompt, then generate 3 examples of effective prompts for this task."
Prompt Optimization Loop
initial_prompt = "Explain machine learning"
optimization_prompt = f"""
Original prompt: "{initial_prompt}"
This prompt is too vague. Improve it by:
1. Adding specific focus area
2. Defining target audience
3. Specifying depth of explanation
4. Setting output format
Provide an optimized version.
"""
System Prompts (API Usage)
# For chat-based models
system_prompt = """You are a Python expert specializing in data science.
Your responses should:
- Include working code examples
- Explain complex concepts simply
- Suggest best practices
- Warn about common pitfalls
- Use type hints and documentation"""
user_prompt = "How do I handle missing data in pandas?"
# API call structure
response = client.chat.completions.create(
model="gpt-4",
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
)
Constitutional AI Prompting
Build in safety and ethical guidelines:
"[Task description]
Guidelines:
- Provide factual, unbiased information
- Acknowledge uncertainty when appropriate
- Avoid harmful or discriminatory content
- Cite sources when making factual claims
- Respect privacy and confidentiality"
Prompt Engineering Tools
LangChain Prompt Templates
from langchain import PromptTemplate
template = """
You are a {role} with expertise in {domain}.
Task: {task}
Context: {context}
Provide your response in {format} format.
"""
prompt = PromptTemplate(
input_variables=["role", "domain", "task", "context", "format"],
template=template
)
final_prompt = prompt.format(
role="data scientist",
domain="machine learning",
task="explain overfitting",
context="teaching beginners",
format="simple terms with examples"
)
Prompt Versioning
# Track prompt iterations
prompts = {
"v1.0": "Summarize this text",
"v1.1": "Summarize this text in 100 words",
"v1.2": "Summarize this text in 100 words, focusing on key insights",
"v2.0": "Provide a 100-word summary highlighting: 1) main argument, 2) supporting evidence, 3) conclusions"
}
Measuring Prompt Quality
Evaluation Criteria
- Consistency: Same prompt → similar outputs
- Accuracy: Outputs match expected results
- Efficiency: Minimal tokens for desired result
- Robustness: Works with variations in input
- Clarity: Unambiguous instructions
Testing Framework
def test_prompt(prompt, test_cases, model):
results = []
for test_input, expected_output in test_cases:
full_prompt = prompt.format(input=test_input)
actual_output = model.generate(full_prompt)
results.append({
'input': test_input,
'expected': expected_output,
'actual': actual_output,
'match': evaluate_match(expected_output, actual_output)
})
return results
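The `evaluate_match` helper in the framework above is left undefined; one simple version is a normalized containment check, shown below. Real evaluations often use exact match, regex, or an LLM judge instead.

```python
def evaluate_match(expected, actual):
    """Loose match: lowercase, collapse whitespace, then check containment."""
    norm = lambda s: " ".join(s.lower().split())
    return norm(expected) in norm(actual)

print(evaluate_match("Paris", "The capital of France is  PARIS."))  # -> True
```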
Resources
Reading
- OpenAI Prompt Engineering Guide
- Anthropic Prompt Library
- Prompt Engineering Guide
- Research papers on prompting techniques
Communities
- r/PromptEngineering
- Discord servers for AI tools
- Twitter/X AI communities
Conclusion
Prompt engineering is an iterative process. Start simple, test thoroughly, and refine based on results. The key is understanding both your task requirements and how the model interprets instructions.
Remember: The best prompt is the one that consistently produces the results you need with minimal tokens and maximum clarity.
Llama Models - Meta AI
Complete guide to Meta's Llama family of open-source language models, from setup to fine-tuning and deployment.
Table of Contents
- Introduction
- Model Versions
- Installation & Setup
- Basic Usage
- Fine-tuning
- Quantization
- Inference Optimization
- Deployment
- Advanced Techniques
Introduction
Llama (Large Language Model Meta AI) is Meta's family of open-source foundation language models. Released as open-weights models, they've become the foundation for countless applications and fine-tuned variants.
Key Features
- Open Source: Freely available weights
- Strong Performance: Competitive with closed models
- Multiple Sizes: From 1B to 70B+ parameters
- Commercial Friendly: Permissive license
- Active Ecosystem: Huge community support
- Efficient: Optimized for deployment
Architecture
- Transformer-based: Decoder-only architecture
- RMSNorm: Root Mean Square Layer Normalization
- SwiGLU: Activation function
- Rotary Embeddings: Position encoding
- Grouped-Query Attention: Efficient attention mechanism
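As a sketch of one of these components: RMSNorm normalizes by the root mean square of the activations, with no mean subtraction and no bias, then applies a learned per-dimension gain. A pure-Python illustration (real implementations operate on tensors):

```python
import math

def rms_norm(x, weight, eps=1e-6):
    """RMSNorm: x / sqrt(mean(x^2) + eps), scaled by a learned per-dim gain."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]

x = [1.0, 2.0, 3.0, 4.0]
y = rms_norm(x, [1.0] * 4)  # gain initialized to 1
print(round(math.sqrt(sum(v * v for v in y) / len(y)), 4))  # ~1.0
```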
Model Versions
Llama 3.2 (Latest)
Released: September 2024
Llama 3.2 1B/3B (Edge Models)
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
model_name = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map="auto"
)
# Chat
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is Python?"}
]
input_ids = tokenizer.apply_chat_template(
messages,
return_tensors="pt"
).to(model.device)
outputs = model.generate(
input_ids,
max_new_tokens=256,
temperature=0.7
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Features:
- 1B and 3B parameter versions
- Optimized for mobile and edge devices
- Multilingual support
- 128K context length
- Excellent for on-device inference
Llama 3.2 11B/90B (Vision Models)
from transformers import MllamaForConditionalGeneration, AutoProcessor
from PIL import Image
model = MllamaForConditionalGeneration.from_pretrained(
"meta-llama/Llama-3.2-11B-Vision-Instruct",
torch_dtype=torch.bfloat16,
device_map="auto"
)
processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")
# Load image
image = Image.open("photo.jpg")
# Create prompt
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What's in this image?"}
]
}
]
# Process and generate
input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(image, input_text, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=256)
print(processor.decode(output[0], skip_special_tokens=True))
Features:
- Multimodal (text + vision)
- 11B and 90B variants
- Image understanding
- Visual question answering
Llama 3.1
Released: July 2024
# 8B - Fast, efficient
model_name = "meta-llama/Llama-3.1-8B-Instruct"
# 70B - High capability
model_name = "meta-llama/Llama-3.1-70B-Instruct"
# 405B - Most capable (requires multiple GPUs)
model_name = "meta-llama/Llama-3.1-405B-Instruct"
Features:
- 128K context window
- Multilingual (8 languages)
- Tool use capabilities
- Improved reasoning
- 8B, 70B, and 405B sizes
Llama 3
Released: April 2024
# 8B
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
# 70B
model_name = "meta-llama/Meta-Llama-3-70B-Instruct"
Features:
- 8K context window
- Strong performance
- Better instruction following
- 8B and 70B sizes
Llama 2
Released: July 2023
# 7B
model_name = "meta-llama/Llama-2-7b-chat-hf"
# 13B
model_name = "meta-llama/Llama-2-13b-chat-hf"
# 70B
model_name = "meta-llama/Llama-2-70b-chat-hf"
Features:
- 4K context window
- 7B, 13B, and 70B sizes
- Still widely used
Model Comparison
| Model | Parameters | Context | VRAM (FP16) | Use Case |
|---|---|---|---|---|
| Llama 3.2 1B | 1B | 128K | 2GB | Edge/Mobile |
| Llama 3.2 3B | 3B | 128K | 6GB | Edge/Desktop |
| Llama 3.1 8B | 8B | 128K | 16GB | Standard |
| Llama 3.2 11B Vision | 11B | 128K | 22GB | Multimodal |
| Llama 3.1 70B | 70B | 128K | 140GB | High-end |
| Llama 3.2 90B Vision | 90B | 128K | 180GB | Vision tasks |
| Llama 3.1 405B | 405B | 128K | 810GB | Best quality |
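The VRAM column is roughly parameters times bytes per parameter (2 bytes for FP16/BF16), plus overhead. A back-of-envelope estimator for the weights alone:

```python
def weight_vram_gb(params_billion, bytes_per_param=2.0):
    """Approximate memory for model weights alone.
    FP16/BF16 = 2 bytes, INT8 = 1, 4-bit = 0.5. Real usage is higher
    due to KV cache, activations, and framework overhead."""
    return params_billion * 1e9 * bytes_per_param / 1024 ** 3

print(round(weight_vram_gb(8), 1))       # FP16 8B: ~14.9 GB of weights
print(round(weight_vram_gb(8, 0.5), 1))  # 4-bit 8B: ~3.7 GB
```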
Installation & Setup
Via Hugging Face Transformers
# Install dependencies
pip install transformers torch accelerate
# For quantization
pip install bitsandbytes
# For training
pip install peft datasets
Via Ollama (Easy Local Setup)
# Install Ollama
curl -fsSL https://ollama.com/install.sh | sh
# Pull model
ollama pull llama3.2
# Run
ollama run llama3.2
Python usage:
import requests
def query_ollama(prompt):
response = requests.post('http://localhost:11434/api/generate',
json={
"model": "llama3.2",
"prompt": prompt,
"stream": False
}
)
return response.json()['response']
result = query_ollama("What is machine learning?")
print(result)
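With `"stream": True`, Ollama returns newline-delimited JSON chunks, each carrying a `response` fragment and a `done` flag. A sketch of reassembling the text (the `sample` lines below are made-up chunks for demonstration):

```python
import json

def parse_ollama_stream(lines):
    """Join the 'response' fragments from Ollama's streaming JSON lines."""
    parts = []
    for line in lines:
        chunk = json.loads(line)
        parts.append(chunk.get("response", ""))
        if chunk.get("done"):
            break
    return "".join(parts)

# With requests: iterate response.iter_lines() from a stream=True POST
sample = ['{"response": "Hel", "done": false}',
          '{"response": "lo!", "done": true}']
print(parse_ollama_stream(sample))  # -> Hello!
```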
Via llama.cpp (Efficient C++ Implementation)
# Clone and build
git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp
make
# Download model (GGUF format)
# From Hugging Face or converted locally
# Run inference
./main -m models/llama-3.2-1B-Instruct-Q4_K_M.gguf -p "Hello, how are you?"
Python bindings:
pip install llama-cpp-python
from llama_cpp import Llama
llm = Llama(
model_path="models/llama-3.2-3B-Instruct-Q4_K_M.gguf",
n_ctx=2048,
n_gpu_layers=35 # Adjust for GPU
)
output = llm(
"Explain quantum computing",
max_tokens=256,
temperature=0.7,
top_p=0.95,
)
print(output['choices'][0]['text'])
Via vLLM (Production Inference)
pip install vllm
from vllm import LLM, SamplingParams
# Load model
llm = LLM(
model="meta-llama/Llama-3.2-3B-Instruct",
tensor_parallel_size=1
)
# Sampling parameters
sampling_params = SamplingParams(
temperature=0.7,
top_p=0.95,
max_tokens=256
)
# Generate
prompts = ["What is AI?", "Explain Python"]
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
print(output.outputs[0].text)
Basic Usage
Simple Text Generation
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Load model
model_name = "meta-llama/Llama-3.2-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map="auto"
)
# Generate
prompt = "Write a Python function to calculate factorial:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=256,
temperature=0.7,
top_p=0.9,
do_sample=True
)
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(result)
Chat Format
# Proper chat formatting
messages = [
{"role": "system", "content": "You are a helpful AI assistant."},
{"role": "user", "content": "What is the capital of France?"},
]
# Apply chat template
input_ids = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
return_tensors="pt"
).to(model.device)
# Generate
outputs = model.generate(
input_ids,
max_new_tokens=256,
temperature=0.7,
top_p=0.9,
do_sample=True,
pad_token_id=tokenizer.eos_token_id
)
response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
print(response)
Multi-turn Conversation
conversation = [
{"role": "system", "content": "You are a helpful assistant."}
]
def chat(user_message):
# Add user message
conversation.append({"role": "user", "content": user_message})
# Generate response
input_ids = tokenizer.apply_chat_template(
conversation,
add_generation_prompt=True,
return_tensors="pt"
).to(model.device)
outputs = model.generate(
input_ids,
max_new_tokens=256,
temperature=0.7,
pad_token_id=tokenizer.eos_token_id
)
response = tokenizer.decode(
outputs[0][input_ids.shape[-1]:],
skip_special_tokens=True
)
# Add assistant response
conversation.append({"role": "assistant", "content": response})
return response
# Use
print(chat("What is Python?"))
print(chat("How do I install it?"))
print(chat("Give me a simple example."))
Streaming Generation
from transformers import TextIteratorStreamer
from threading import Thread
streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
# Prepare input
messages = [{"role": "user", "content": "Write a short story about AI"}]
input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)
# Generate in thread
generation_kwargs = {
"input_ids": input_ids,
"max_new_tokens": 512,
"temperature": 0.8,
"streamer": streamer
}
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
# Stream output
for text in streamer:
print(text, end="", flush=True)
thread.join()
Fine-tuning
QLoRA Fine-tuning (Most Popular)
Efficient fine-tuning with quantization:
pip install transformers peft accelerate bitsandbytes datasets
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import load_dataset
import torch
# Load model with quantization
model_name = "meta-llama/Llama-3.2-3B-Instruct"
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True
)
model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=bnb_config,
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
# Prepare model
model = prepare_model_for_kbit_training(model)
# LoRA configuration
lora_config = LoraConfig(
r=16, # Rank
lora_alpha=32, # Scaling factor
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Output: trainable params: ~16M / total: 3B (~0.5%)
# Prepare dataset
dataset = load_dataset("your-dataset")
def format_instruction(example):
text = f"### Instruction:\n{example['instruction']}\n\n### Response:\n{example['response']}"
return {"text": text}
dataset = dataset.map(format_instruction)
# Tokenize
def tokenize(examples):
return tokenizer(
examples["text"],
truncation=True,
max_length=512,
padding="max_length"
)
tokenized_dataset = dataset.map(tokenize, batched=True)
# Training arguments
training_args = TrainingArguments(
output_dir="./llama-finetuned",
num_train_epochs=3,
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
learning_rate=2e-4,
fp16=True,
logging_steps=10,
save_strategy="epoch",
optim="paged_adamw_8bit"
)
# Train
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset["train"],
# datasets returns Python lists, so build tensors with torch.tensor
data_collator=lambda data: {
'input_ids': torch.tensor([f['input_ids'] for f in data]),
'attention_mask': torch.tensor([f['attention_mask'] for f in data]),
'labels': torch.tensor([f['input_ids'] for f in data])
}
)
trainer.train()
# Save
model.save_pretrained("./llama-lora")
tokenizer.save_pretrained("./llama-lora")
Using Fine-tuned Model
from peft import PeftModel
# Load base model
base_model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-3B-Instruct",
torch_dtype=torch.float16,
device_map="auto"
)
# Load LoRA
model = PeftModel.from_pretrained(base_model, "./llama-lora")
# Generate
tokenizer = AutoTokenizer.from_pretrained("./llama-lora")
inputs = tokenizer("Your prompt here", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Full Fine-tuning (Requires More Resources)
from transformers import Trainer, TrainingArguments
# Load model normally (no quantization)
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-3B-Instruct",
torch_dtype=torch.bfloat16,
device_map="auto"
)
# Training arguments
training_args = TrainingArguments(
output_dir="./llama-fullft",
num_train_epochs=3,
per_device_train_batch_size=2,
gradient_accumulation_steps=8,
learning_rate=1e-5,
bf16=True,
logging_steps=10,
save_strategy="epoch",
deepspeed="ds_config.json" # For multi-GPU
)
# Train
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset["train"]
)
trainer.train()
Using Axolotl (Simplified Training)
# Install
git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl
pip install -e .
Create config llama_qlora.yml:
base_model: meta-llama/Llama-3.2-3B-Instruct
model_type: LlamaForCausalLM
load_in_4bit: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
datasets:
- path: your-dataset
type: alpaca
num_epochs: 3
micro_batch_size: 4
gradient_accumulation_steps: 4
learning_rate: 0.0002
output_dir: ./llama-qlora-out
Train:
accelerate launch -m axolotl.cli.train llama_qlora.yml
Quantization
BitsAndBytes Quantization
from transformers import BitsAndBytesConfig
# 8-bit
bnb_config_8bit = BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0
)
# 4-bit (QLoRA)
bnb_config_4bit = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4", # or "fp4"
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True
)
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-3B-Instruct",
quantization_config=bnb_config_4bit,
device_map="auto"
)
GGUF Quantization (llama.cpp)
# Convert to GGUF
python convert_hf_to_gguf.py \
models/Llama-3.2-3B-Instruct \
--outfile llama-3.2-3b-instruct.gguf
# Quantize
./quantize \
llama-3.2-3b-instruct.gguf \
llama-3.2-3b-instruct-Q4_K_M.gguf \
Q4_K_M
Quantization formats:
- Q4_0: 4-bit, fastest, lowest quality
- Q4_K_M: 4-bit, good quality (recommended)
- Q5_K_M: 5-bit, better quality
- Q8_0: 8-bit, high quality
GPTQ Quantization
pip install auto-gptq
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
# Quantize
quantize_config = BaseQuantizeConfig(
bits=4,
group_size=128,
desc_act=False
)
model = AutoGPTQForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-3B-Instruct",
quantize_config=quantize_config
)
# Save
model.save_quantized("llama-3.2-3b-gptq")
# Load
model = AutoGPTQForCausalLM.from_quantized(
"llama-3.2-3b-gptq",
device_map="auto"
)
AWQ Quantization
pip install autoawq
from awq import AutoAWQForCausalLM
# Quantize
model = AutoAWQForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
model.quantize(tokenizer, quant_config={"zero_point": True, "q_group_size": 128})
model.save_quantized("llama-3.2-3b-awq")
# Load
model = AutoAWQForCausalLM.from_quantized("llama-3.2-3b-awq")
Inference Optimization
Flash Attention 2
pip install flash-attn --no-build-isolation
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-3B-Instruct",
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
device_map="auto"
)
Batch Inference
# Process multiple prompts efficiently
prompts = [
"What is Python?",
"Explain machine learning",
"How do computers work?"
]
# Tokenize with padding
inputs = tokenizer(
prompts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=512
).to(model.device)
# Generate
outputs = model.generate(
**inputs,
max_new_tokens=256,
temperature=0.7,
do_sample=True,
pad_token_id=tokenizer.pad_token_id
)
# Decode
results = tokenizer.batch_decode(outputs, skip_special_tokens=True)
for prompt, result in zip(prompts, results):
print(f"Q: {prompt}\nA: {result}\n")
KV Cache Optimization
# Enable static KV cache for faster inference
model.generation_config.cache_implementation = "static"
model.generation_config.max_length = 512
# Or use with generate
outputs = model.generate(
input_ids,
max_new_tokens=256,
use_cache=True,
cache_implementation="static"
)
TensorRT-LLM
# Build TensorRT engine
git clone https://github.com/NVIDIA/TensorRT-LLM.git
cd TensorRT-LLM
# Convert and build
python examples/llama/convert_checkpoint.py \
--model_dir models/Llama-3.2-3B-Instruct \
--output_dir ./trt_ckpt \
--dtype float16
trtllm-build \
--checkpoint_dir ./trt_ckpt \
--output_dir ./trt_engine \
--gemm_plugin float16
Deployment
FastAPI Server
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
app = FastAPI()
# Load model once at startup
model_name = "meta-llama/Llama-3.2-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map="auto"
)
class GenerateRequest(BaseModel):
prompt: str
max_tokens: int = 256
temperature: float = 0.7
@app.post("/generate")
async def generate(request: GenerateRequest):
inputs = tokenizer(request.prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=request.max_tokens,
temperature=request.temperature,
do_sample=True
)
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
return {"response": result}
# Run: uvicorn server:app --host 0.0.0.0 --port 8000
vLLM Server
# Start server
vllm serve meta-llama/Llama-3.2-3B-Instruct \
--host 0.0.0.0 \
--port 8000 \
--tensor-parallel-size 1
# Client
curl http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "meta-llama/Llama-3.2-3B-Instruct",
"prompt": "What is AI?",
"max_tokens": 256
}'
Python client:
from openai import OpenAI
client = OpenAI(
base_url="http://localhost:8000/v1",
api_key="dummy"
)
response = client.completions.create(
model="meta-llama/Llama-3.2-3B-Instruct",
prompt="Explain quantum computing",
max_tokens=256
)
print(response.choices[0].text)
Text Generation Inference (TGI)
# Docker
docker run --gpus all --shm-size 1g -p 8080:80 \
-v $PWD/data:/data \
ghcr.io/huggingface/text-generation-inference:latest \
--model-id meta-llama/Llama-3.2-3B-Instruct
# Client
curl http://localhost:8080/generate \
-X POST \
-d '{"inputs":"What is Python?","parameters":{"max_new_tokens":256}}' \
-H 'Content-Type: application/json'
LangChain Integration
pip install langchain langchain-community
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# Load model
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-3B-Instruct",
torch_dtype=torch.float16,
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
# Create pipeline
pipe = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
max_new_tokens=256,
temperature=0.7
)
# LangChain LLM
llm = HuggingFacePipeline(pipeline=pipe)
# Create chain
template = "Question: {question}\n\nAnswer:"
prompt = PromptTemplate(template=template, input_variables=["question"])
chain = LLMChain(llm=llm, prompt=prompt)
# Use
result = chain.run("What is machine learning?")
print(result)
Advanced Techniques
Retrieval-Augmented Generation (RAG)
pip install langchain chromadb sentence-transformers
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFacePipeline
# Load documents
documents = ["Your document text here..."]
# Split text
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
texts = text_splitter.create_documents(documents)
# Create embeddings
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# Create vector store
vectorstore = Chroma.from_documents(texts, embeddings)
# Create QA chain
qa_chain = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=vectorstore.as_retriever()
)
# Query
query = "What does the document say about AI?"
result = qa_chain.run(query)
print(result)
Function Calling
import json
def get_current_weather(location: str, unit: str = "celsius"):
"""Get current weather for a location"""
# Simulated function
return {"location": location, "temperature": 22, "unit": unit}
# Define tools
tools = [
{
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {"type": "string", "description": "City name"},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}
},
"required": ["location"]
}
}
]
# System prompt
system_prompt = f"""You are a helpful assistant with access to tools.
Available tools: {json.dumps(tools, indent=2)}
When you need to use a tool, output JSON: {{"tool": "tool_name", "parameters": {{...}}}}
"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": "What's the weather in Paris?"}
]
# Generate
input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)
outputs = model.generate(input_ids, max_new_tokens=256)
response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
# Parse and execute tool call (guarded: the model may answer in plain text)
try:
    tool_call = json.loads(response.strip())
except json.JSONDecodeError:
    tool_call = None
if tool_call and tool_call.get("tool") == "get_current_weather":
    result = get_current_weather(**tool_call["parameters"])
    print(f"Weather: {result}")
Constrained Generation
pip install outlines
import outlines
# Load model
model = outlines.models.transformers("meta-llama/Llama-3.2-1B-Instruct")
# JSON schema constraint
schema = """{
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"},
"skills": {"type": "array", "items": {"type": "string"}}
}
}"""
generator = outlines.generate.json(model, schema)
result = generator("Generate a person profile:")
print(result)
# Regex constraint
phone_pattern = r"\d{3}-\d{3}-\d{4}"
generator = outlines.generate.regex(model, phone_pattern)
phone = generator("Generate a US phone number:")
print(phone)
Best Practices
1. Model Selection
# Choose based on requirements
model_selection = {
"mobile/edge": "meta-llama/Llama-3.2-1B-Instruct",
"desktop/low_vram": "meta-llama/Llama-3.2-3B-Instruct",
"standard": "meta-llama/Llama-3.1-8B-Instruct",
"high_quality": "meta-llama/Llama-3.1-70B-Instruct",
"vision": "meta-llama/Llama-3.2-11B-Vision-Instruct"
}
2. Prompt Templates
# Use consistent templates
SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant."
def format_chat(user_message, system=SYSTEM_PROMPT):
return [
{"role": "system", "content": system},
{"role": "user", "content": user_message}
]
3. Memory Management
import torch
import gc
def clear_memory():
gc.collect()
torch.cuda.empty_cache()
# After large operations
outputs = model.generate(...)
result = tokenizer.decode(outputs[0])
del outputs
clear_memory()
4. Error Handling
def safe_generate(prompt, max_retries=3):
for attempt in range(max_retries):
try:
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=256)
return tokenizer.decode(outputs[0], skip_special_tokens=True)
except RuntimeError as e:
if "out of memory" in str(e) and attempt < max_retries - 1:
torch.cuda.empty_cache()
continue
raise
Resources
Community
- r/LocalLLaMA
- Hugging Face Forums
- Discord communities
Conclusion
Llama models provide a powerful, open-source foundation for AI applications. Whether you're running a 1B model on a mobile device or deploying a 70B model in production, the ecosystem offers tools and techniques for every use case.
Key takeaways:
- Start small: Test with 1B/3B models first
- Quantize: Use 4-bit for efficient inference
- Fine-tune: QLoRA for custom domains
- Optimize: vLLM/TGI for production
- Monitor: Watch memory and performance
The open-source nature and active community make Llama models an excellent choice for both research and production applications.
Stable Diffusion
Complete guide to Stable Diffusion for image generation, from setup to advanced techniques.
Table of Contents
- Introduction
- Installation & Setup
- Model Versions
- Prompt Engineering
- Parameters
- Advanced Techniques
- Extensions & Tools
- Optimization
- Common Issues
Introduction
Stable Diffusion is an open-source text-to-image diffusion model capable of generating high-quality images from text descriptions. Unlike proprietary alternatives, it can run locally on consumer hardware.
Key Features
- Open Source: Free to use and modify
- Local Execution: Run on your own hardware
- Extensible: ControlNet, LoRA, extensions
- Fast: Optimized inference with various schedulers
- Flexible: Text-to-image, image-to-image, inpainting
Installation & Setup
Option 1: AUTOMATIC1111 WebUI (Most Popular)
# Clone repository
git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
cd stable-diffusion-webui
# Install (Linux/Mac)
./webui.sh
# Install (Windows)
# Double-click webui-user.bat
# With custom arguments
# Edit webui-user.sh or webui-user.bat:
export COMMANDLINE_ARGS="--xformers --medvram --api"
System Requirements:
- GPU: NVIDIA (8GB+ VRAM recommended)
- RAM: 16GB+ system RAM
- Storage: 10GB+ for models
Option 2: ComfyUI (Node-Based)
# Clone repository
git clone https://github.com/comfyanonymous/ComfyUI.git
cd ComfyUI
# Install dependencies
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip install -r requirements.txt
# Run
python main.py
# With arguments
python main.py --listen --port 8188
Option 3: Python Library (Diffusers)
pip install diffusers transformers accelerate torch torchvision
from diffusers import StableDiffusionPipeline
import torch
# Load model
model_id = "stabilityai/stable-diffusion-2-1"
pipe = StableDiffusionPipeline.from_pretrained(
model_id,
torch_dtype=torch.float16
)
pipe = pipe.to("cuda")
# Enable optimizations
pipe.enable_attention_slicing()
pipe.enable_vae_slicing()
# Generate
prompt = "a beautiful landscape"
image = pipe(prompt).images[0]
image.save("output.png")
Option 4: Invoke AI
pip install invokeai
invokeai-configure
invokeai --web
Model Versions
Stable Diffusion 1.x
SD 1.4
# Download location
models/Stable-diffusion/sd-v1-4.ckpt
- Resolution: 512x512
- Training: LAION-2B subset
- Good for: General use
SD 1.5
# Most popular 1.x version
git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
- Improved over 1.4
- Massive ecosystem of fine-tunes
- Best model support
Stable Diffusion 2.x
SD 2.0
- Resolution: 768x768
- New text encoder (OpenCLIP)
- Better quality but different style
SD 2.1
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-1",
torch_dtype=torch.float16
)
- Improvements over 2.0
- Recommended 2.x version
Stable Diffusion XL (SDXL)
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
# Base model
base = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
variant="fp16"
)
base.to("cuda")
# Refiner (optional, improves quality)
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0",
torch_dtype=torch.float16,
variant="fp16"
)
refiner.to("cuda")
# Generate
image = base(prompt="a futuristic city").images[0]
image = refiner(prompt="a futuristic city", image=image).images[0]
Features:
- Resolution: 1024x1024
- Higher quality
- Better text rendering
- Dual text encoders
- Requires more VRAM (8GB+)
Stable Diffusion 3
Latest version with improved architecture:
- Multimodal diffusion transformer
- Better prompt understanding
- Improved composition
Prompt Engineering
Basic Structure
[Subject] [Action/Scene] [Environment] [Lighting] [Style] [Quality]
Effective Prompts
Basic:
"a cat"
Better:
"a fluffy orange cat sitting on a windowsill"
Best:
"a fluffy orange tabby cat sitting on a wooden windowsill, looking outside at falling snow, soft natural lighting, cozy atmosphere, detailed fur texture, photorealistic, 4k, highly detailed"
Prompt Components
1. Subject
"portrait of a young woman"
"a medieval castle"
"a steampunk airship"
"cyberpunk street scene"
2. Action/Pose
"running through a field"
"sitting in contemplation"
"dancing under moonlight"
"reading a book by firelight"
3. Environment
"in a mystical forest"
"on an alien planet"
"in a Victorian library"
"at a bustling marketplace"
4. Lighting
"golden hour lighting"
"dramatic rim lighting"
"soft diffused light"
"neon lights reflecting on wet streets"
"volumetric fog with god rays"
5. Style
"oil painting style"
"anime art style"
"photorealistic"
"watercolor painting"
"digital art, trending on artstation"
"in the style of Greg Rutkowski"
6. Quality Boosters
"highly detailed"
"8k resolution"
"masterpiece"
"professional photography"
"award-winning"
"intricate details"
"sharp focus"
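These components can also be assembled programmatically. A minimal sketch (the function and argument names are illustrative, not part of any library):

```python
def build_prompt(subject, action="", environment="", lighting="", style="", quality=()):
    """Join non-empty prompt components in the order used above."""
    parts = [subject, action, environment, lighting, style, *quality]
    return ", ".join(p for p in parts if p)

prompt = build_prompt(
    "a fluffy orange tabby cat",
    action="sitting on a wooden windowsill",
    environment="looking outside at falling snow",
    lighting="soft natural lighting",
    style="photorealistic",
    quality=("highly detailed", "4k"),
)
# "a fluffy orange tabby cat, sitting on a wooden windowsill, ..., highly detailed, 4k"
```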
Negative Prompts
What to avoid in generation:
Negative Prompt:
"ugly, blurry, low quality, distorted, deformed, bad anatomy, poorly drawn, low resolution, watermark, signature, text, cropped, worst quality, jpeg artifacts"
Common Negative Terms:
- Quality: blurry, low quality, pixelated, grainy
- Anatomy: bad anatomy, extra limbs, malformed hands
- Artifacts: watermark, text, signature, logo
- Style: cartoon (for photorealistic), realistic (for artistic)
Prompt Weighting
Emphasize or de-emphasize parts:
# AUTOMATIC1111 syntax
(keyword) # 1.1x weight
((keyword)) # 1.21x weight
(keyword:1.5) # 1.5x weight
[keyword] # 0.9x weight
Example:
"a (beautiful:1.3) landscape with (mountains:1.2) and (trees:0.8)"
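The weighting rules can be made concrete with a small helper (a sketch for single tokens, not part of AUTOMATIC1111 itself):

```python
def attention_weight(token):
    """Effective weight of one emphasized token in AUTOMATIC1111 syntax."""
    inner = token.strip("()[]")
    if ":" in inner:
        return float(inner.split(":")[1])  # explicit weight, e.g. (keyword:1.5)
    opens = len(token) - len(token.lstrip("("))
    if opens:
        return round(1.1 ** opens, 4)      # each paren pair multiplies by 1.1
    brackets = len(token) - len(token.lstrip("["))
    if brackets:
        return round(0.9 ** brackets, 4)   # each bracket pair multiplies by 0.9
    return 1.0
```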
Prompt Editing
Change prompts during generation:
# AUTOMATIC1111 syntax
[keyword1:keyword2:step]
Example:
"a [dog:cat:0.5]"
# Generates dog for first 50% of steps, then cat
"photo of a woman [smiling:serious:10]"
# Smiling for first 10 steps, then serious
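The switch point follows a simple rule, sketched here: a fractional `when` means a share of total steps, an integer means an absolute step.

```python
def prompt_edit_switch_step(total_steps, when):
    """Step at which [from:to:when] switches from the first prompt to the second."""
    if when < 1:
        return round(total_steps * when)  # fraction of total steps
    return int(when)                      # absolute step number

prompt_edit_switch_step(50, 0.5)  # 25
prompt_edit_switch_step(30, 10)   # 10
```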
Artist Styles
Reference famous artists:
"in the style of Van Gogh"
"by Greg Rutkowski"
"by Alphonse Mucha"
"by Simon Stalenhag"
"by Artgerm"
"by Ilya Kuvshinov"
Parameters
Core Parameters
Steps (num_inference_steps)
# Fewer steps = faster, less refined
image = pipe(prompt, num_inference_steps=20)
# More steps = slower, more refined
image = pipe(prompt, num_inference_steps=50)
Recommendations:
- Quick preview: 15-20 steps
- Standard quality: 25-35 steps
- High quality: 40-60 steps
- Diminishing returns after 60
CFG Scale (guidance_scale)
How closely to follow the prompt:
# Low CFG = creative, less adherence
image = pipe(prompt, guidance_scale=3.5)
# Medium CFG = balanced
image = pipe(prompt, guidance_scale=7.5)
# High CFG = strict adherence, may oversaturate
image = pipe(prompt, guidance_scale=15)
Recommendations:
- Creative/artistic: 5-7
- Balanced: 7-10
- Strict/detailed: 10-15
- Avoid: >20 (over-saturated)
Seed
Reproducible results:
# Random seed
image = pipe(prompt)
# Fixed seed for reproducibility
generator = torch.Generator("cuda").manual_seed(42)
image = pipe(prompt, generator=generator)
Sampler/Scheduler
Different algorithms for denoising:
from diffusers import (
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
DDIMScheduler
)
# Fast and high quality (recommended)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
pipe.scheduler.config
)
# More creative, varied
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipe.scheduler.config
)
# Stable, predictable
pipe.scheduler = DDIMScheduler.from_config(
pipe.scheduler.config
)
Popular Samplers:
- DPM++ 2M Karras: Fast, high quality (recommended)
- Euler a: Creative, varied results
- DDIM: Stable, reproducible
- UniPC: Very fast, good quality
- DPM++ SDE Karras: High quality, slower
Resolution
# SD 1.5 native: 512x512
image = pipe(prompt, width=512, height=512)
# SD 2.1 native: 768x768
image = pipe(prompt, width=768, height=768)
# SDXL native: 1024x1024
image = pipe(prompt, width=1024, height=1024)
# Portrait
image = pipe(prompt, width=512, height=768)
# Landscape
image = pipe(prompt, width=768, height=512)
Tips:
- Stick to multiples of 8 or 64
- Native resolution gives best results
- Higher resolution needs more VRAM
- Use upscaling for ultra-high resolution
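To keep dimensions on valid multiples, a small helper (illustrative) can snap arbitrary sizes:

```python
def snap_resolution(width, height, multiple=64):
    """Round both dimensions down to the nearest accepted multiple."""
    snap = lambda v: max(multiple, (v // multiple) * multiple)
    return snap(width), snap(height)

snap_resolution(513, 770)  # (512, 768)
```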
Batch Settings
# Generate multiple images
images = pipe(
prompt,
num_images_per_prompt=4,
guidance_scale=7.5
).images
# Save all
for i, img in enumerate(images):
img.save(f"output_{i}.png")
Advanced Techniques
Image-to-Image
Transform existing images:
from diffusers import StableDiffusionImg2ImgPipeline
from PIL import Image
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-1",
torch_dtype=torch.float16
).to("cuda")
# Load image
init_image = Image.open("input.jpg").convert("RGB")
init_image = init_image.resize((768, 768))
# Transform
prompt = "a fantasy castle, magical, highly detailed"
images = pipe(
prompt=prompt,
image=init_image,
strength=0.75, # 0=no change, 1=complete regeneration
guidance_scale=7.5,
num_inference_steps=50
).images[0]
images.save("transformed.png")
Strength Parameter:
- 0.1-0.3: Minor adjustments, preserve structure
- 0.4-0.6: Moderate changes, guided by original
- 0.7-0.9: Major changes, loose interpretation
- 1.0: Complete regeneration
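The strength bands above can be codified as a small lookup helper (hypothetical, useful for presets):

```python
def describe_strength(strength):
    """Map an img2img strength to the behaviour bands listed above."""
    if not 0.0 <= strength <= 1.0:
        raise ValueError("strength must be between 0 and 1")
    if strength >= 1.0:
        return "complete regeneration"
    if strength >= 0.7:
        return "major changes, loose interpretation"
    if strength >= 0.4:
        return "moderate changes, guided by original"
    return "minor adjustments, preserve structure"
```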
Inpainting
Edit specific parts of images:
from diffusers import StableDiffusionInpaintPipeline
pipe = StableDiffusionInpaintPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-inpainting",
torch_dtype=torch.float16
).to("cuda")
# Load image and mask
image = Image.open("photo.png")
mask = Image.open("mask.png") # White = inpaint, Black = keep
prompt = "a red vintage car"
result = pipe(
prompt=prompt,
image=image,
mask_image=mask,
num_inference_steps=50,
guidance_scale=7.5
).images[0]
result.save("inpainted.png")
ControlNet
Precise control over generation:
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from PIL import Image
import cv2
import numpy as np
# Load ControlNet model
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/sd-controlnet-canny",
torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
controlnet=controlnet,
torch_dtype=torch.float16
).to("cuda")
# Load image and create canny edge map
image = Image.open("input.jpg")
image = np.array(image)
canny_edges = cv2.Canny(image, 100, 200)
canny_edges = Image.fromarray(canny_edges)
# Generate with control
prompt = "a professional architectural photograph"
output = pipe(
prompt=prompt,
image=canny_edges,
num_inference_steps=30
).images[0]
ControlNet Models:
- Canny: Edge detection
- Depth: Depth map
- OpenPose: Human pose
- Scribble: Hand-drawn sketches
- Normal: Normal maps
- Segmentation: Semantic segmentation
- MLSD: Line detection (architecture)
LoRA (Low-Rank Adaptation)
Fine-tuned models with small file size:
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16
).to("cuda")
# Load LoRA
pipe.load_lora_weights("path/to/lora.safetensors")
# Generate with LoRA style
prompt = "a portrait in the style of <lora-trigger-word>"
image = pipe(prompt).images[0]
# Unload LoRA
pipe.unload_lora_weights()
Popular LoRA Types:
- Character/celebrity faces
- Art styles
- Concepts
- Objects/clothing
Textual Inversion
Custom concepts/embeddings:
# Load embedding
pipe.load_textual_inversion("path/to/embedding.pt", token="<special-token>")
# Use in prompt
prompt = "a photo of <special-token> in a forest"
image = pipe(prompt).images[0]
Upscaling
Increase resolution with detail:
from diffusers import StableDiffusionUpscalePipeline
# Load upscaler
upscaler = StableDiffusionUpscalePipeline.from_pretrained(
"stabilityai/stable-diffusion-x4-upscaler",
torch_dtype=torch.float16
).to("cuda")
# Load low-res image
low_res = Image.open("output_512.png")
# Upscale
prompt = "highly detailed, sharp, professional"
upscaled = upscaler(
prompt=prompt,
image=low_res,
num_inference_steps=50
).images[0]
upscaled.save("output_2048.png")
Upscaling Options:
- SD Upscale: Built-in SD upscaler
- Real-ESRGAN: Traditional upscaler
- Ultimate SD Upscale: Tiled upscaling
- ControlNet Tile: Detail-preserving upscale
Extensions & Tools
AUTOMATIC1111 Extensions
Install via Extensions tab or:
cd extensions
git clone [extension-repo-url]
Essential Extensions
ControlNet
git clone https://github.com/Mikubill/sd-webui-controlnet.git
Dynamic Prompts
git clone https://github.com/adieyal/sd-dynamic-prompts.git
- Wildcard support: {red|blue|green} car
- Combinatorial generation
Image Browser
git clone https://github.com/AlUlkesh/stable-diffusion-webui-images-browser.git
- Browse generated images
- Search by metadata
Cutoff
git clone https://github.com/hnmr293/sd-webui-cutoff.git
- Prevent color bleeding between subjects
Regional Prompter
git clone https://github.com/hako-mikan/sd-webui-regional-prompter.git
- Different prompts for image regions
Checkpoint Merging
Combine models:
from diffusers import StableDiffusionPipeline
import torch
# Load two models
pipe1 = StableDiffusionPipeline.from_pretrained("model1")
pipe2 = StableDiffusionPipeline.from_pretrained("model2")
# Merge (0.5 = 50/50 blend); build a new state dict and load it back,
# since mutating the dict returned by state_dict() does not change the model
alpha = 0.5
sd1 = pipe1.unet.state_dict()
sd2 = pipe2.unet.state_dict()
merged = {key: alpha * sd1[key] + (1 - alpha) * sd2[key] for key in sd1}
pipe1.unet.load_state_dict(merged)
# Save merged model
pipe1.save_pretrained("merged_model")
Prompt Matrix
Test multiple prompt variants (combinatorial {a|b} syntax from the Dynamic Prompts extension):
# In AUTOMATIC1111 with Dynamic Prompts
Prompt: a {red|blue|green} {car|house} in a forest
Generates:
- a red car in a forest
- a red house in a forest
- a blue car in a forest
- a blue house in a forest
- a green car in a forest
- a green house in a forest
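The combinatorial expansion shown above can be sketched in plain Python (`{a|b}` is the Dynamic Prompts wildcard convention; this helper is illustrative):

```python
import itertools
import re

def expand_variants(template):
    """Expand every {a|b|c} group into all combinations."""
    groups = re.findall(r"\{([^}]*)\}", template)
    options = [g.split("|") for g in groups]
    results = []
    for combo in itertools.product(*options):
        out = template
        for choice in combo:
            # Replace the leftmost remaining group with this combo's choice
            out = re.sub(r"\{[^}]*\}", choice, out, count=1)
        results.append(out)
    return results

expand_variants("a {red|blue|green} {car|house} in a forest")
# 6 prompts, starting with "a red car in a forest"
```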
Optimization
Memory Optimization
from diffusers import StableDiffusionPipeline
import torch
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16 # Half precision
).to("cuda")
# Enable memory optimizations
pipe.enable_attention_slicing() # Reduce memory
pipe.enable_vae_slicing() # Reduce VAE memory
pipe.enable_xformers_memory_efficient_attention() # Faster attention
# For very low VRAM (4GB)
pipe.enable_sequential_cpu_offload()
Speed Optimization
# Use faster scheduler
from diffusers import DPMSolverMultistepScheduler
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
pipe.scheduler.config
)
# Compile model (PyTorch 2.0+)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead")
# Reduce steps with quality scheduler
image = pipe(prompt, num_inference_steps=20) # vs 50 with others
VRAM Requirements
| Configuration | Minimum VRAM |
|---|---|
| SD 1.5 (512x512) | 4GB |
| SD 1.5 (512x512, optimized) | 2GB |
| SD 2.1 (768x768) | 6GB |
| SDXL (1024x1024) | 8GB |
| SDXL (1024x1024, optimized) | 6GB |
| ControlNet + SD | +2GB |
| Batch size 2 | +2GB per image |
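The table can be turned into a rough estimator (a hypothetical helper; real usage also varies with resolution and optimizations):

```python
BASE_VRAM_GB = {"sd15": 4, "sd15-optimized": 2, "sd21": 6, "sdxl": 8, "sdxl-optimized": 6}

def estimate_vram_gb(model, controlnet=False, batch_size=1):
    """Rough minimum-VRAM estimate, following the table above."""
    gb = BASE_VRAM_GB[model]
    if controlnet:
        gb += 2                 # ControlNet adds ~2GB
    gb += 2 * (batch_size - 1)  # ~2GB per extra image in the batch
    return gb
```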
Launch Arguments (AUTOMATIC1111)
# Basic optimization
--xformers # Memory-efficient attention
--medvram # Medium VRAM optimization
--lowvram # Low VRAM optimization
--no-half-vae # Fix black images on some GPUs
# API
--api # Enable API
--listen # Allow network connections
# Performance
--opt-sdp-attention # Scaled dot product attention
--no-gradio-queue # Disable queue
# Example combination
./webui.sh --xformers --medvram --api --no-half-vae
Common Issues
Black Images
# Solution: Disable half precision for VAE
--no-half-vae
Or in Python:
pipe.vae.to(torch.float32)
Out of Memory (OOM)
# Enable all optimizations
pipe.enable_attention_slicing()
pipe.enable_vae_slicing()
pipe.enable_sequential_cpu_offload()
# Reduce resolution
image = pipe(prompt, width=512, height=512)
# Reduce batch size
image = pipe(prompt, num_images_per_prompt=1)
Bad Hands/Anatomy
Negative prompt: "bad hands, bad anatomy, extra fingers, missing fingers, deformed hands, poorly drawn hands"
# Or use inpainting to fix
# Or use ControlNet OpenPose for guidance
Inconsistent Results
# Use fixed seed
generator = torch.Generator("cuda").manual_seed(42)
image = pipe(prompt, generator=generator)
# Use a deterministic sampler (DDIM instead of Euler a)
Prompt Not Working
- Check prompt weighting: (keyword:1.3)
- Use negative prompt to exclude unwanted elements
- Increase CFG scale
- Try different sampler
- Add quality boosters: "highly detailed, 8k"
Best Practices
1. Prompt Structure
[Quality] [Style] [Subject] [Action] [Environment] [Lighting] [Details]
Example:
"masterpiece, best quality, photorealistic, portrait of a young woman, smiling, in a sunlit garden, golden hour lighting, detailed facial features, professional photography, 8k uhd"
2. Iterative Refinement
# Start with low steps for preview
preview = pipe(prompt, num_inference_steps=15).images[0]
# Refine with more steps
final = pipe(prompt, num_inference_steps=50).images[0]
# Upscale for details
upscaled = upscale(final)
3. Seed Management
# Save seeds for good results
good_seeds = []
for seed in range(100):
gen = torch.Generator("cuda").manual_seed(seed)
image = pipe(prompt, generator=gen).images[0]
if is_good(image):
good_seeds.append(seed)
image.save(f"good_{seed}.png")
4. Negative Prompts Library
negative_prompts = {
'photorealistic': "anime, cartoon, drawing, painting, low quality",
'artistic': "photorealistic, photo, realistic, low quality",
'quality': "ugly, blurry, low quality, low resolution, pixelated",
'anatomy': "bad anatomy, extra limbs, poorly drawn, deformed",
'artifacts': "watermark, signature, text, logo, copyright"
}
# Combine as needed
negative = ", ".join([
negative_prompts['quality'],
negative_prompts['anatomy'],
negative_prompts['artifacts']
])
Resources
Models
- Hugging Face
- Civitai - Community models, LoRAs
- Stability AI
Communities
- Discord: Stable Diffusion
- Reddit: r/StableDiffusion
- Twitter/X: #StableDiffusion
Conclusion
Stable Diffusion offers incredible flexibility and power for image generation. Success comes from understanding the fundamentals, experimenting with parameters, and iterating on prompts. Start simple, learn the basics, then explore advanced techniques like ControlNet and LoRA for professional results.
Flux.1 - Black Forest Labs
Complete guide to Flux.1, the next-generation image generation model from the creators of Stable Diffusion.
Table of Contents
- Introduction
- Model Variants
- Installation & Setup
- Usage
- Prompt Engineering
- Parameters
- Comparison with Other Models
- Advanced Techniques
- Optimization
Introduction
Flux.1 is a state-of-the-art image generation model developed by Black Forest Labs, the team behind the original Stable Diffusion. Released in 2024, it represents a significant advancement in image quality, prompt adherence, and detail preservation.
Key Features
- Superior Image Quality: Enhanced detail and realism
- Better Prompt Understanding: More accurate interpretation
- Improved Text Rendering: Readable text in images
- Flexible Architecture: Multiple variants for different needs
- Advanced Control: Fine-grained control over generation
- Fast Inference: Optimized for speed
Model Architecture
- Flow Matching: Advanced diffusion technique
- Hybrid Architecture: Combines transformer and diffusion
- 12B Parameters: Larger than SD models
- Parallel Attention: Efficient processing
- Rotary Position Embeddings (RoPE): Better spatial understanding
Model Variants
Flux.1 [pro]
Commercial, API-only
import requests
API_URL = "https://api.bfl.ml/v1/flux-pro"
API_KEY = "your-api-key"
def generate_flux_pro(prompt):
headers = {
"Authorization": f"Bearer {API_KEY}",
"Content-Type": "application/json"
}
payload = {
"prompt": prompt,
"width": 1024,
"height": 1024,
"steps": 30
}
response = requests.post(API_URL, json=payload, headers=headers)
return response.json()
# Generate
result = generate_flux_pro(
"a professional photograph of a modern office, natural lighting, detailed"
)
Features:
- Highest quality
- Best prompt adherence
- Commercial use allowed
- API access only
- Pay per generation
Best for:
- Professional work
- Commercial projects
- Maximum quality needs
Flux.1 [dev]
Non-commercial, open-weight
import torch
from diffusers import FluxPipeline
# Load model
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.bfloat16
)
pipe.to("cuda")
# Generate
prompt = "a majestic lion in the savanna at sunset, highly detailed"
image = pipe(
prompt,
guidance_scale=3.5,
num_inference_steps=30,
height=1024,
width=1024,
).images[0]
image.save("flux_output.png")
Features:
- High quality
- Open weights
- Non-commercial license
- Requires Hugging Face auth
- Can run locally
Requirements:
- GPU: 24GB+ VRAM (recommended)
- RAM: 32GB+ system RAM
- Storage: ~30GB for model
Best for:
- Research and development
- Personal projects
- Learning and experimentation
Flux.1 [schnell]
Apache 2.0 license, fastest
from diffusers import FluxPipeline
import torch
# Load schnell variant
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=torch.bfloat16
)
pipe.to("cuda")
# Fast generation (1-4 steps)
prompt = "a portrait of a person, professional photography"
image = pipe(
prompt,
num_inference_steps=4, # Very few steps needed
guidance_scale=0.0, # No guidance needed
height=1024,
width=1024,
).images[0]
image.save("schnell_output.png")
Features:
- Very fast (1-4 steps)
- Good quality
- Apache 2.0 license
- Commercial use allowed
- Lower VRAM requirements
Best for:
- Real-time applications
- High-volume generation
- Commercial projects
- Resource-constrained environments
Installation & Setup
Option 1: Diffusers (Recommended)
# Install dependencies
pip install diffusers transformers accelerate torch
# Install from latest
pip install git+https://github.com/huggingface/diffusers.git
from diffusers import FluxPipeline
import torch
# Authenticate with Hugging Face
from huggingface_hub import login
login(token="your_hf_token")
# Load model
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload() # Save VRAM
# Generate
image = pipe("a beautiful landscape").images[0]
Option 2: ComfyUI
# Update ComfyUI
cd ComfyUI
git pull
# Download Flux models to:
# models/unet/flux1-dev.safetensors
# models/unet/flux1-schnell.safetensors
# Download CLIP and T5 encoders to:
# models/clip/clip_l.safetensors
# models/clip/t5xxl_fp16.safetensors
# Download VAE to:
# models/vae/ae.safetensors
Option 3: AUTOMATIC1111 (via extension)
cd extensions
git clone https://github.com/XLabs-AI/x-flux-comfyui.git
# Restart WebUI
Hardware Requirements
| Variant | Minimum VRAM | Recommended VRAM | Storage |
|---|---|---|---|
| Schnell | 12GB | 16GB | 30GB |
| Dev | 16GB | 24GB | 30GB |
| Pro | N/A (API) | N/A (API) | N/A |
Optimizations:
- bfloat16: Reduces VRAM by ~50%
- CPU offload: Reduces VRAM usage further
- Quantization: 8-bit or 4-bit for lower VRAM
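The bfloat16 figure follows from a simple rule of thumb: weights alone take parameters × bytes per parameter (activations and the text encoders add more on top):

```python
def model_vram_gb(n_params, bytes_per_param=2):
    """Rough weight memory: 12B params at bfloat16 (2 bytes each) ≈ 22.4 GB."""
    return n_params * bytes_per_param / 1024**3

model_vram_gb(12e9)     # ≈ 22.4 — why 24GB VRAM is recommended for Dev
model_vram_gb(12e9, 1)  # ≈ 11.2 — with 8-bit quantization
```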
Usage
Basic Generation
from diffusers import FluxPipeline
import torch
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.bfloat16
)
pipe.to("cuda")
# Simple generation
prompt = "a serene mountain lake at sunrise"
image = pipe(prompt).images[0]
image.save("output.png")
With Parameters
image = pipe(
prompt="a futuristic city with flying cars, neon lights, cyberpunk",
height=1024,
width=1024,
num_inference_steps=30,
guidance_scale=3.5,
max_sequence_length=256,
).images[0]
Batch Generation
# Multiple images from one prompt
images = pipe(
prompt="a cute cat",
num_images_per_prompt=4,
num_inference_steps=30,
).images
for i, img in enumerate(images):
img.save(f"cat_{i}.png")
Seed Control
# Fixed seed for reproducibility
generator = torch.Generator("cuda").manual_seed(42)
image = pipe(
prompt="a magical forest",
generator=generator,
num_inference_steps=30,
).images[0]
Memory-Efficient Generation
# For lower VRAM: pick ONE CPU-offload strategy (they are mutually exclusive)
pipe.enable_model_cpu_offload()         # moderate savings, faster
# pipe.enable_sequential_cpu_offload()  # maximum savings, much slower
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
# Generate
image = pipe(
prompt="a detailed landscape",
height=1024,
width=1024,
).images[0]
Prompt Engineering
Prompt Structure
Flux.1 has excellent prompt understanding. Use natural language:
# Simple and effective
prompt = "a portrait of a woman with red hair, wearing a blue dress, in a garden"
# Detailed
prompt = """
a professional photograph of a young woman with flowing red hair,
wearing an elegant blue silk dress, standing in a lush garden
with blooming roses, soft natural lighting, golden hour,
depth of field, bokeh background, shot on Canon EOS R5
"""
# With style
prompt = """
oil painting of a medieval knight in full armor,
standing on a cliff overlooking the ocean at sunset,
dramatic lighting, renaissance art style,
highly detailed, masterpiece
"""
Natural Language
Flux excels with conversational prompts:
prompts = [
"Show me a cat wearing sunglasses at the beach",
"Create an image of a steampunk airship flying over Victorian London",
"Paint a serene Japanese garden in autumn with falling maple leaves",
"Design a futuristic sports car that looks fast even when standing still"
]
Text in Images
Flux.1 can render text (unlike most other models):
# Text rendering
prompt = '''
a modern cafe storefront with a neon sign that says "COFFEE SHOP",
rainy evening, reflections on wet pavement, cinematic lighting
'''
# Book cover
prompt = '''
a fantasy book cover with the title "The Dragon's Tale"
written in elegant golden letters at the top,
featuring a majestic dragon flying over mountains
'''
# Product mockup
prompt = '''
a white t-shirt with the text "FLUX.1" printed in bold black letters,
product photography, plain white background, professional lighting
'''
Aspect Ratios
# Portrait
image = pipe(prompt, height=1344, width=768).images[0]
# Landscape
image = pipe(prompt, height=768, width=1344).images[0]
# Square
image = pipe(prompt, height=1024, width=1024).images[0]
# Cinematic
image = pipe(prompt, height=576, width=1024).images[0]
# Ultra-wide
image = pipe(prompt, height=512, width=1536).images[0]
Prompt Tips
- Be Specific: More detail = better results
- Natural Language: Write as you would describe to a person
- Quality Terms: "professional", "detailed", "high quality"
- Style References: "photograph", "oil painting", "digital art"
- Lighting: "golden hour", "dramatic lighting", "soft light"
- Camera/Lens: "50mm lens", "wide angle", "macro"
Example Prompts
# Photorealistic
prompt = """
a cinematic photograph of a lone astronaut standing on mars,
red desert landscape, distant sun on horizon,
dust particles in air, dramatic lighting,
shot on ARRI Alexa, anamorphic lens
"""
# Artistic
prompt = """
watercolor painting of a coastal village,
Mediterranean architecture, boats in harbor,
soft pastel colors, impressionist style,
painted by Claude Monet
"""
# Product
prompt = """
professional product photography of a luxury watch,
silver metal band, blue dial face,
on marble surface with dramatic side lighting,
reflections, 8k resolution, advertising quality
"""
# Character
prompt = """
character design of a cyberpunk hacker,
purple mohawk, neon goggles, leather jacket with patches,
detailed facial features, full body illustration,
concept art style, trending on artstation
"""
# Architecture
prompt = """
modern minimalist house in forest setting,
large glass windows, wooden exterior,
surrounded by tall pine trees, morning mist,
architectural photography, professional real estate photo
"""
Parameters
num_inference_steps
Number of denoising steps:
# Schnell: 1-4 steps (optimized for speed)
image = pipe(prompt, num_inference_steps=4).images[0]
# Dev: 20-50 steps (balance)
image = pipe(prompt, num_inference_steps=30).images[0]
# Pro: API manages automatically
Recommendations:
- Schnell: 1-4 (4 recommended)
- Dev: 20-30 (30 recommended)
- More steps = better quality but slower
guidance_scale
How closely to follow the prompt:
# Schnell: 0.0 (no guidance needed)
image = pipe(prompt, guidance_scale=0.0).images[0]
# Dev: 3.0-5.0 (3.5 recommended)
image = pipe(prompt, guidance_scale=3.5).images[0]
Flux uses lower guidance than SD:
- SD typical: 7-10
- Flux typical: 3-5
max_sequence_length
Token limit for prompt:
# Standard
image = pipe(prompt, max_sequence_length=256).images[0]
# Long prompts
image = pipe(prompt, max_sequence_length=512).images[0]
Resolution
# Standard resolutions (in pixels)
resolutions = {
"square": (1024, 1024),
"portrait": (768, 1344),
"landscape": (1344, 768),
"wide": (1536, 640),
"tall": (640, 1536),
}
# Use (tuples above are (width, height))
image = pipe(
    prompt,
    width=resolutions["landscape"][0],
    height=resolutions["landscape"][1]
).images[0]
Notes:
- Keep dimensions divisible by 16
- Total pixels should be ~1MP for best results
- Higher resolutions need more VRAM
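The notes above (≈1MP total, dimensions divisible by 16) can be combined into a helper that derives a resolution from an aspect ratio (illustrative):

```python
import math

def flux_dims(aspect_ratio, total_pixels=1024 * 1024, multiple=16):
    """Width/height near the target pixel count for a given w/h ratio, snapped to 16."""
    width = math.sqrt(total_pixels * aspect_ratio)
    height = width / aspect_ratio
    snap = lambda v: max(multiple, round(v / multiple) * multiple)
    return snap(width), snap(height)

flux_dims(1.0)     # (1024, 1024)
flux_dims(16 / 9)  # (1360, 768)
```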
Comparison with Other Models
Flux.1 vs Stable Diffusion
| Feature | Flux.1 | SD 1.5 | SDXL |
|---|---|---|---|
| Image Quality | ⭐⭐⭐⭐⭐ | ⭐⭐⭐ | ⭐⭐⭐⭐ |
| Prompt Adherence | ⭐⭐⭐⭐⭐ | ⭐⭐⭐ | ⭐⭐⭐⭐ |
| Text Rendering | ⭐⭐⭐⭐⭐ | ⭐ | ⭐⭐ |
| Speed (Dev) | ⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐ |
| Speed (Schnell) | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐ |
| VRAM Usage | ⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐ |
| Ecosystem | ⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐⭐⭐ |
| License | Varies | Open | Open |
Quality Comparison
# Same prompt across models
prompt = "a detailed portrait of a person with glasses"
# Flux.1 Dev
flux_image = flux_pipe(prompt, num_inference_steps=30).images[0]
# Result: High detail, accurate glasses, natural lighting
# SDXL
sdxl_image = sdxl_pipe(prompt, num_inference_steps=30).images[0]
# Result: Good quality, some artifacts
# SD 1.5
sd15_image = sd15_pipe(prompt, num_inference_steps=30).images[0]
# Result: Lower quality, potential distortions
Strengths of Flux.1
- Superior Detail: Finer details in textures, faces, objects
- Better Composition: More coherent scene layouts
- Text Rendering: Can actually render readable text
- Prompt Understanding: Better interpretation of complex prompts
- Natural Images: More photorealistic when requested
Strengths of Stable Diffusion
- Ecosystem: Vast library of models, LoRAs, tools
- VRAM Efficiency: Runs on lower-end hardware
- Community: Large community, extensive documentation
- Extensions: ControlNet, regional prompting, etc.
- Customization: Easy to fine-tune and merge
When to Use Each
Use Flux.1 when:
- Maximum quality is priority
- Need text in images
- Want natural, detailed results
- Have adequate hardware
- Creating professional content
Use Stable Diffusion when:
- Need specific styles (anime, etc.)
- Want to use LoRAs/embeddings
- Limited VRAM (<12GB)
- Need extensive control (ControlNet)
- Large existing workflow
Advanced Techniques
Image-to-Image (via Diffusers)
import torch
from diffusers import FluxImg2ImgPipeline
from PIL import Image
# Load pipeline
pipe = FluxImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16
)
pipe.to("cuda")
# Load input image
init_image = Image.open("input.jpg").convert("RGB")
# Transform
prompt = "transform into an oil painting, artistic style"
image = pipe(
    prompt=prompt,
    image=init_image,
    strength=0.75,
    num_inference_steps=30,
    guidance_scale=3.5
).images[0]
ControlNet (via third-party)
# Note: Official ControlNet not yet released
# Community implementations available
# Example with a community (X-Labs) implementation; module and class
# names below follow that project and may differ between releases
from flux_control import FluxControlNetPipeline
pipe = FluxControlNetPipeline.from_pretrained(
    "XLabs-AI/flux-controlnet-canny",
    torch_dtype=torch.bfloat16
)
# Use canny edge detection (generate_canny is a placeholder for your own
# edge-detection helper, e.g. cv2.Canny on the input image)
control_image = generate_canny(input_image)
output = pipe(prompt, control_image=control_image).images[0]
LoRA Fine-tuning
import torch
from diffusers import FluxPipeline
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16
)
# Load LoRA (when available)
pipe.load_lora_weights("path/to/flux-lora.safetensors")
# Generate with LoRA style
prompt = "a portrait in the custom style"
image = pipe(prompt).images[0]
Batching for Efficiency
# Generate multiple variations
prompts = [
    "a red car",
    "a blue car",
    "a green car",
    "a yellow car"
]
images = []
for prompt in prompts:
    image = pipe(prompt, num_inference_steps=30).images[0]
    images.append(image)
# Or use batch processing if memory allows
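Diffusers pipelines also accept a list of prompts in a single call (`pipe(prompt=batch, ...).images`), which batches the denoising pass. A small helper — hypothetical, not part of diffusers — to split a prompt list into VRAM-sized batches:

```python
def chunked(items, batch_size):
    # Yield successive batches small enough to fit in VRAM
    for i in range(0, len(items), batch_size):
        yield items[i:i + batch_size]

prompts = ["a red car", "a blue car", "a green car", "a yellow car"]
batches = list(chunked(prompts, 2))
# each batch can then go through the pipeline in one call:
# images = pipe(prompt=batch, num_inference_steps=30).images
```

Batch size 1 reduces this to the sequential loop above; increase it until you hit your VRAM limit.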
Optimization
Memory Optimization
from diffusers import FluxPipeline
import torch
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16
)
# Enable CPU offloading
pipe.enable_model_cpu_offload()
# Enable VAE optimizations
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
# For extreme memory savings, use sequential offload instead of
# enable_model_cpu_offload() above (pick one, not both)
pipe.enable_sequential_cpu_offload()
Speed Optimization
# Use Schnell for speed
pipe_schnell = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16
)
# Compile for faster inference (PyTorch 2.0+); Flux uses a DiT transformer, not a UNet
pipe_schnell.transformer = torch.compile(pipe_schnell.transformer, mode="reduce-overhead")
# Use fewer steps
image = pipe_schnell(prompt, num_inference_steps=4).images[0]
Quantization
# 4-bit quantization via bitsandbytes
from transformers import BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)
# Note: pipeline-level quantization support depends on the diffusers
# version; older versions require quantizing the transformer and text
# encoders individually
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16
)
Multi-GPU
# Distribute across GPUs
from accelerate import PartialState
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.bfloat16
)
distributed_state = PartialState()
pipe.to(distributed_state.device)
API Usage (Flux Pro)
REST API
import requests
import base64
from io import BytesIO
from PIL import Image
API_URL = "https://api.bfl.ml/v1/flux-pro"
API_KEY = "your-api-key"
def generate_flux_pro(
    prompt,
    width=1024,
    height=1024,
    steps=30,
    guidance=3.5
):
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        "prompt": prompt,
        "width": width,
        "height": height,
        "steps": steps,
        "guidance": guidance
    }
    response = requests.post(API_URL, json=payload, headers=headers)
    if response.status_code == 200:
        image_data = response.json()["image"]
        image = Image.open(BytesIO(base64.b64decode(image_data)))
        return image
    else:
        raise Exception(f"API Error: {response.text}")

# Generate
image = generate_flux_pro(
    "a beautiful sunset over mountains",
    width=1344,
    height=768
)
image.save("pro_output.png")
Async API
import asyncio
import aiohttp
async def generate_async(prompt):
    async with aiohttp.ClientSession() as session:
        headers = {"Authorization": f"Bearer {API_KEY}"}
        payload = {"prompt": prompt}
        async with session.post(API_URL, json=payload, headers=headers) as resp:
            return await resp.json()
# Use
image_data = asyncio.run(generate_async("a futuristic city"))
Tips & Best Practices
1. Prompt Quality
# Good prompt skeletons for Flux
good_prompts = [
    "a cinematic photograph of [subject], [details], [lighting], [camera]",
    "an oil painting of [scene], [style], by [artist]",
    "product photography of [item], [background], professional lighting",
    "character design of [character], [details], concept art"
]
2. Iteration Strategy
# Start with Schnell for quick iterations
quick_pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16
)
# Iterate quickly
for variation in range(5):
    gen = torch.Generator("cuda").manual_seed(variation)
    preview = quick_pipe(
        prompt,
        num_inference_steps=4,
        generator=gen
    ).images[0]
    preview.save(f"preview_{variation}.png")
# Refine the winner with Dev (assumes dev_pipe, final_prompt and
# winning_seed were chosen during the Schnell iterations above)
final = dev_pipe(
    final_prompt,
    num_inference_steps=30,
    generator=torch.Generator("cuda").manual_seed(winning_seed)
).images[0]
3. VRAM Management
# Monitor VRAM
import torch
print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"Reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
# Clear cache between generations
torch.cuda.empty_cache()
4. Prompt Templates
templates = {
    "portrait": "{subject}, {expression}, {clothing}, {background}, portrait photography, {lighting}",
    "landscape": "{location}, {time_of_day}, {weather}, {style}, landscape photography",
    "product": "product photography of {product}, {surface}, {lighting}, professional, commercial",
    "artistic": "{style} of {subject}, {details}, by {artist}, masterpiece"
}
# Use
prompt = templates["portrait"].format(
    subject="a young woman",
    expression="slight smile",
    clothing="elegant dress",
    background="bokeh lights",
    lighting="soft natural light"
)
Resources
Official
Community
- r/FluxAI
- Hugging Face Discussions
- Discord communities
Tools
Learning
- Flux.1 Paper
- Comparison benchmarks
- Community prompts
Conclusion
Flux.1 represents a significant leap in image generation quality. While it requires more resources than Stable Diffusion, the results are often worth it for professional applications. The Schnell variant offers excellent speed-to-quality ratio, while Dev provides maximum quality for local generation.
Key takeaways:
- Schnell: Fast, commercial-friendly, good quality
- Dev: Best local quality, non-commercial
- Pro: Highest quality, API-only, commercial
Choose based on your needs, hardware, and use case. Experiment with natural language prompts and leverage Flux's superior understanding for best results.
ComfyUI
https://github.com/comfyanonymous/ComfyUI
https://docs.comfy.org/get_started/manual_install
git clone https://github.com/comfyanonymous/ComfyUI.git
Flux workflow guide: https://comfyui-wiki.com/tutorial/advanced/flux1-comfyui-guide-workflow-and-examples
Fine-Tuning
Fine-tuning is the process of taking a pre-trained model and further training it on a specific task or dataset.
Overview
Fine-tuning adapts a general-purpose model to a specific domain or task with much less data and compute than training from scratch.
Approaches:
- Full fine-tuning: Update all parameters
- Parameter-efficient: Update subset (LoRA, adapters)
- Few-shot prompting: No parameter updates
When to Fine-Tune
✅ Good use cases:
- Domain-specific language (medical, legal)
- Specific output format requirements
- Improved performance on narrow tasks
- Style adaptation
❌ Bad use cases:
- General knowledge (use prompting)
- Limited data (< 100 examples)
- When prompting works well enough
Fine-Tuning Process
# Example with Hugging Face
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
# 1. Load pre-trained model
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# 2. Prepare dataset
train_dataset = load_dataset("your_dataset")
# 3. Configure training
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    learning_rate=2e-5,
)
# 4. Train
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)
trainer.train()
LoRA (Low-Rank Adaptation)
from peft import LoraConfig, get_peft_model
# Configure LoRA
config = LoraConfig(
    r=8,  # Rank
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.1,
    task_type="CAUSAL_LM",
)
# Apply LoRA to model
model = get_peft_model(model, config)
model.print_trainable_parameters()
# trainable params: ~0.1% of the total (vs 100% for full fine-tuning)
Data Preparation
{"prompt": "Translate to French: Hello", "completion": "Bonjour"}
{"prompt": "Translate to French: Goodbye", "completion": "Au revoir"}
{"prompt": "Translate to French: Thank you", "completion": "Merci"}
Best Practices
- Start with quality data: 100-1000 high-quality examples
- Use parameter-efficient methods: LoRA for large models
- Monitor overfitting: Validate on held-out data
- Experiment with hyperparameters: Learning rate, batch size
- Evaluate systematically: Don't just rely on loss
- Consider data augmentation: Increase training data
- Version control: Track model versions and data
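"Monitor overfitting" starts with a held-out split; a minimal sketch in plain Python (the helper name is illustrative):

```python
import random

def train_val_split(examples, val_fraction=0.1, seed=42):
    # Hold out a validation set so overfitting shows up during training
    shuffled = examples[:]
    random.Random(seed).shuffle(shuffled)
    n_val = max(1, int(len(shuffled) * val_fraction))
    return shuffled[n_val:], shuffled[:n_val]

data = [{"prompt": f"example {i}", "completion": "..."} for i in range(100)]
train, val = train_val_split(data)
```

Fixing the seed keeps the split reproducible across runs, so validation loss curves stay comparable between experiments.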
Evaluation
from sklearn.metrics import accuracy_score
# generate() returns token IDs; decode them before comparing to labels
output_ids = model.generate(test_inputs)
predictions = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
accuracy = accuracy_score(test_labels, predictions)
print(f"Accuracy: {accuracy}")
Fine-tuning enables customization of powerful pre-trained models for specific applications with minimal resources.
Cloud Computing Overview
Table of Contents
- Introduction
- Cloud Service Models
- Cloud Deployment Models
- Major Cloud Providers
- Common Cloud Services
- Cloud Architecture Patterns
- Cost Optimization
- Security Best Practices
- Choosing a Cloud Provider
Introduction
Cloud computing is the delivery of computing services—including servers, storage, databases, networking, software, analytics, and intelligence—over the Internet ("the cloud") to offer faster innovation, flexible resources, and economies of scale.
Key Benefits
- Cost Savings: Pay only for what you use (OpEx vs CapEx)
- Scalability: Scale up or down based on demand
- Performance: Access to latest hardware and global infrastructure
- Speed: Deploy resources in minutes
- Reliability: Data backup, disaster recovery, business continuity
- Security: Enterprise-grade security features
Cloud Service Models
┌─────────────────────────────────────────────────────────────┐
│ Cloud Service Models │
├─────────────────────────────────────────────────────────────┤
│ │
│ IaaS (Infrastructure as a Service) │
│ ├─ You Manage: Applications, Data, Runtime, Middleware, OS │
│ └─ Provider Manages: Virtualization, Servers, Storage, Net │
│ │
│ PaaS (Platform as a Service) │
│ ├─ You Manage: Applications, Data │
│ └─ Provider Manages: Runtime, Middleware, OS, Infra │
│ │
│ SaaS (Software as a Service) │
│ ├─ You Manage: Data/Configuration │
│ └─ Provider Manages: Everything else │
│ │
│ FaaS (Function as a Service / Serverless) │
│ ├─ You Manage: Code/Functions │
│ └─ Provider Manages: Everything else + Auto-scaling │
└─────────────────────────────────────────────────────────────┘
IaaS - Infrastructure as a Service
Examples: AWS EC2, Azure VMs, Google Compute Engine
Use Cases:
- Hosting websites and web applications
- High-performance computing
- Big data analysis
- Backup and recovery
Control Level: High
Management Overhead: High
PaaS - Platform as a Service
Examples: AWS Elastic Beanstalk, Azure App Service, Google App Engine
Use Cases:
- Application development and deployment
- API development and management
- Business analytics/intelligence
Control Level: Medium
Management Overhead: Medium
SaaS - Software as a Service
Examples: Gmail, Office 365, Salesforce, Dropbox
Use Cases:
- Email and collaboration
- CRM and ERP systems
- Productivity applications
Control Level: Low
Management Overhead: Low
FaaS - Function as a Service
Examples: AWS Lambda, Azure Functions, Google Cloud Functions
Use Cases:
- Event-driven applications
- Real-time file processing
- Scheduled tasks
- Microservices
Control Level: Low (code only)
Management Overhead: Very Low
Cloud Deployment Models
Public Cloud
- Resources owned and operated by third-party provider
- Services delivered over the internet
- Examples: AWS, Azure, GCP
Pros: Cost-effective, scalable, no maintenance
Cons: Less control, potential security concerns
Private Cloud
- Infrastructure used exclusively by a single organization
- Can be hosted on-premises or by third party
Pros: More control, enhanced security, compliance
Cons: Higher cost, maintenance overhead
Hybrid Cloud
- Combination of public and private clouds
- Data and applications shared between them
Pros: Flexibility, cost optimization, compliance options
Cons: Complexity, integration challenges
Multi-Cloud
- Using multiple cloud providers simultaneously
- Avoid vendor lock-in
Pros: Best-of-breed services, redundancy
Cons: Increased complexity, management overhead
Major Cloud Providers
Comparison Matrix
┌────────────────┬──────────────┬──────────────┬──────────────┐
│ Feature │ AWS │ Azure │ GCP │
├────────────────┼──────────────┼──────────────┼──────────────┤
│ Market Share │ ~32% │ ~23% │ ~10% │
│ Launch Year │ 2006 │ 2010 │ 2008 │
│ Regions │ 30+ │ 60+ │ 35+ │
│ Services │ 200+ │ 200+ │ 100+ │
│ Strengths │ Maturity │ Enterprise │ ML/Data │
│ │ Breadth │ Integration │ Analytics │
│ Best For │ Startups │ .NET/Windows │ Big Data │
│ │ Flexibility │ Hybrid │ ML/AI │
└────────────────┴──────────────┴──────────────┴──────────────┘
AWS (Amazon Web Services)
- Founded: 2006
- Market Leader: Largest market share
- Strengths: Broad service portfolio, mature ecosystem, extensive documentation
- Popular Services: EC2, S3, Lambda, RDS, DynamoDB
Microsoft Azure
- Founded: 2010
- Second Largest: Strong enterprise presence
- Strengths: Hybrid cloud, Windows/Microsoft integration, Active Directory
- Popular Services: Virtual Machines, Blob Storage, Azure Functions, SQL Database
Google Cloud Platform (GCP)
- Founded: 2008
- Third Largest: Growing rapidly
- Strengths: Data analytics, machine learning, Kubernetes (GKE)
- Popular Services: Compute Engine, Cloud Storage, BigQuery, Cloud Functions
Other Providers
- IBM Cloud: Enterprise focus, AI (Watson)
- Oracle Cloud: Database workloads
- Alibaba Cloud: Asia-Pacific region
- DigitalOcean: Simple, developer-friendly
Common Cloud Services
Compute Services
Service Type AWS Azure GCP
─────────────────────────────────────────────────────────────
Virtual Machines EC2 Virtual Machines Compute Engine
Containers ECS/EKS/Fargate Container Inst. GKE/Cloud Run
Serverless Lambda Functions Cloud Functions
Auto Scaling Auto Scaling VM Scale Sets Autoscaler
Storage Services
Service Type AWS Azure GCP
─────────────────────────────────────────────────────────────
Object Storage S3 Blob Storage Cloud Storage
Block Storage EBS Disk Storage Persistent Disk
File Storage EFS Files Filestore
Archive Glacier Archive Storage Archive Storage
Database Services
Service Type AWS Azure GCP
─────────────────────────────────────────────────────────────
Relational DB RDS SQL Database Cloud SQL
NoSQL Document DocumentDB Cosmos DB Firestore
NoSQL Key-Value DynamoDB Table Storage Datastore
In-Memory Cache ElastiCache Cache for Redis Memorystore
Data Warehouse Redshift Synapse Analytics BigQuery
Networking Services
Service Type AWS Azure GCP
─────────────────────────────────────────────────────────────
Virtual Network VPC Virtual Network VPC
Load Balancer ELB/ALB Load Balancer Cloud Load Bal.
CDN CloudFront CDN Cloud CDN
DNS Route 53 DNS Cloud DNS
VPN VPN Gateway VPN Gateway Cloud VPN
Cloud Architecture Patterns
1. Multi-Tier Architecture
┌─────────────────┐
│ Load Balancer │
└────────┬────────┘
│
┌───────────────────┼───────────────────┐
│ │ │
┌────▼────┐ ┌────▼────┐ ┌────▼────┐
│ Web │ │ Web │ │ Web │
│ Server │ │ Server │ │ Server │
└────┬────┘ └────┬────┘ └────┬────┘
│ │ │
└───────────────────┼──────────────────┘
│
┌────────▼────────┐
│ App Tier │
│ (Business │
│ Logic) │
└────────┬────────┘
│
┌────────▼────────┐
│ Database Tier │
│ (Primary + │
│ Replica) │
└─────────────────┘
2. Microservices Architecture
┌─────────┐ ┌──────────────────────────────────────────┐
│ API │───▶│ API Gateway │
│ Client │ └──────────┬───────────────────────────────┘
└─────────┘ │
│
┌─────────────────┼─────────────────┬─────────────┐
│ │ │ │
┌────▼─────┐ ┌────▼─────┐ ┌────▼─────┐ ┌────▼─────┐
│ User │ │ Product │ │ Order │ │ Payment │
│ Service │ │ Service │ │ Service │ │ Service │
└────┬─────┘ └────┬─────┘ └────┬─────┘ └────┬─────┘
│ │ │ │
┌────▼─────┐ ┌────▼─────┐ ┌────▼─────┐ ┌────▼─────┐
│ User DB │ │Product DB│ │ Order DB │ │Payment DB│
└──────────┘ └──────────┘ └──────────┘ └──────────┘
3. Event-Driven Architecture
┌──────────┐ ┌──────────────┐ ┌──────────────┐
│ Producer │─────▶│ Message │─────▶│ Consumer 1 │
│ Service │ │ Queue/Topic │ └──────────────┘
└──────────┘ │ (SQS/SNS/ │ ┌──────────────┐
│ EventBridge)│─────▶│ Consumer 2 │
└──────────────┘ └──────────────┘
│ ┌──────────────┐
└─────────────▶│ Consumer 3 │
└──────────────┘
4. Serverless Architecture
┌─────────┐ ┌──────────┐ ┌─────────────┐ ┌──────────┐
│ Client │───▶│ API │───▶│ Lambda │───▶│ Database │
│ │ │ Gateway │ │ Functions │ │ (DynamoDB│
└─────────┘ └──────────┘ └─────────────┘ │ /RDS) │
│ └──────────┘
│
▼
┌─────────────┐
│ Storage │
│ (S3) │
└─────────────┘
Cost Optimization
Pricing Models
1. On-Demand
- Pay for compute capacity by the hour/second
- No long-term commitments
- Best for: Short-term, unpredictable workloads
2. Reserved Instances
- 1 or 3-year commitment
- Up to 75% discount vs on-demand
- Best for: Steady-state workloads
3. Spot/Preemptible Instances
- Up to 90% discount vs on-demand
- Can be terminated with short notice
- Best for: Batch jobs, fault-tolerant workloads
4. Savings Plans
- Flexible pricing model
- Commitment to consistent usage
- Up to 72% discount
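The headline discounts translate directly into monthly dollars; an illustrative comparison (the hourly rate is hypothetical, not a real AWS price):

```python
HOURS_PER_MONTH = 730
on_demand_rate = 0.0416  # hypothetical $/hour; real prices vary by instance and region

# Apply the maximum discounts listed above to the on-demand baseline
pricing = {
    "on_demand": on_demand_rate,
    "reserved_3yr": on_demand_rate * (1 - 0.75),
    "spot": on_demand_rate * (1 - 0.90),
    "savings_plan": on_demand_rate * (1 - 0.72),
}
monthly = {name: round(rate * HOURS_PER_MONTH, 2) for name, rate in pricing.items()}
# e.g. on_demand ~ $30.37/month vs spot ~ $3.04/month
```

The same instance running all month costs an order of magnitude less on spot, which is why batch and fault-tolerant workloads are steered there.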
Cost Optimization Strategies
┌──────────────────────────────────────────────────────────┐
│ Cost Optimization Best Practices │
├──────────────────────────────────────────────────────────┤
│ 1. Right-sizing │
│ └─ Match instance types to actual needs │
│ │
│ 2. Auto-scaling │
│ └─ Scale resources based on demand │
│ │
│ 3. Reserved Instances │
│ └─ Commit to predictable workloads │
│ │
│ 4. Spot Instances │
│ └─ Use for fault-tolerant workloads │
│ │
│ 5. Storage Lifecycle Policies │
│ └─ Move data to cheaper tiers over time │
│ │
│ 6. Delete Unused Resources │
│ └─ Regular audits and cleanup │
│ │
│ 7. Use Serverless │
│ └─ Pay only for execution time │
│ │
│ 8. Monitor and Alert │
│ └─ Set up cost budgets and alerts │
└──────────────────────────────────────────────────────────┘
Monthly Cost Estimation Example
Service Configuration Monthly Cost (Approx)
─────────────────────────────────────────────────────────────────
EC2 (t3.medium) 730 hours on-demand $30
EBS (100 GB) General Purpose SSD $10
RDS (db.t3.small) PostgreSQL, 730 hours $25
S3 (100 GB) Standard storage $2.30
Data Transfer 50 GB outbound $4.50
─────────
Total: ~$72/month
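The table's line items can be sanity-checked in a couple of lines:

```python
# Line items from the estimate above (USD per month)
monthly_costs = {
    "EC2 (t3.medium)": 30.00,
    "EBS (100 GB)": 10.00,
    "RDS (db.t3.small)": 25.00,
    "S3 (100 GB)": 2.30,
    "Data Transfer (50 GB)": 4.50,
}
total = round(sum(monthly_costs.values()), 2)  # 71.8, i.e. "~$72/month"
```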
Security Best Practices
1. Identity and Access Management (IAM)
Best Practices:
├─ Use principle of least privilege
├─ Enable Multi-Factor Authentication (MFA)
├─ Rotate credentials regularly
├─ Use roles instead of access keys when possible
├─ Implement password policies
└─ Audit permissions regularly
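Least privilege is easiest to see in a concrete policy document; a minimal sketch granting read-only access to a single hypothetical bucket:

```python
import json

# "example-bucket" is a hypothetical name; scope Action and Resource
# as narrowly as the workload allows
least_privilege_policy = {
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Action": ["s3:ListBucket", "s3:GetObject"],
            "Resource": [
                "arn:aws:s3:::example-bucket",
                "arn:aws:s3:::example-bucket/*",
            ],
        }
    ],
}

policy_json = json.dumps(least_privilege_policy, indent=2)
# pass as --policy-document to `aws iam create-policy`
```

Compare this with `s3:*` on `Resource: "*"`: the narrow policy cannot delete objects or touch any other bucket, which limits the blast radius of leaked credentials.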
2. Network Security
┌─────────────────────────────────────────────┐
│ VPC Security │
├─────────────────────────────────────────────┤
│ │
│ Public Subnet │
│ ┌──────────────────────────────────┐ │
│ │ Load Balancer │ │
│ │ (Security Group: HTTP/HTTPS) │ │
│ └──────────────┬───────────────────┘ │
│ │ │
│ Private Subnet │ │
│ ┌──────────────▼───────────────┐ │
│ │ Application Servers │ │
│ │ (SG: From LB only) │ │
│ └──────────────┬───────────────┘ │
│ │ │
│ Database Subnet│ │
│ ┌──────────────▼───────────────┐ │
│ │ Database │ │
│ │ (SG: From App only) │ │
│ └──────────────────────────────┘ │
└─────────────────────────────────────────────┘
3. Data Protection
- Encryption at Rest: Enable for all storage services
- Encryption in Transit: Use TLS/SSL for all communications
- Backup and Recovery: Regular automated backups
- Data Classification: Tag and classify sensitive data
4. Monitoring and Logging
Security Monitoring Stack:
├─ CloudWatch/Azure Monitor - Metrics and logs
├─ CloudTrail/Activity Log - API call auditing
├─ GuardDuty/Defender - Threat detection
├─ Security Hub/Security Center - Compliance
└─ SIEM Integration - Centralized monitoring
5. Compliance
Common compliance frameworks:
- GDPR: European data protection
- HIPAA: Healthcare data
- PCI DSS: Payment card data
- SOC 2: Security and availability
- ISO 27001: Information security
Choosing a Cloud Provider
Decision Matrix
Factor Weight AWS Azure GCP
─────────────────────────────────────────────────
Existing Ecosystem High ★★★★ ★★★★★ ★★★
Services Breadth High ★★★★★ ★★★★★ ★★★★
Pricing Medium ★★★★ ★★★★ ★★★★★
Documentation Medium ★★★★★ ★★★★ ★★★★
Support Medium ★★★★ ★★★★★ ★★★
ML/AI Capabilities Varies ★★★★ ★★★★ ★★★★★
Kubernetes Varies ★★★★ ★★★★ ★★★★★
Global Reach High ★★★★★ ★★★★★ ★★★★
Use Case Recommendations
Choose AWS if:
- Need broadest service selection
- Want mature ecosystem and tooling
- Building greenfield applications
- Need strong serverless capabilities
Choose Azure if:
- Heavy Microsoft/Windows workloads
- Need hybrid cloud capabilities
- Enterprise Active Directory integration
- Existing Microsoft licensing
Choose GCP if:
- Focus on data analytics and ML
- Need best-in-class Kubernetes
- Want innovative technologies
- Prioritize BigQuery for analytics
Use Multi-Cloud if:
- Need to avoid vendor lock-in
- Want best-of-breed services
- Have compliance requirements
- Can manage the complexity
Getting Started
Learning Path
1. Fundamentals (1-2 weeks)
├─ Cloud concepts and terminology
├─ Choose a primary provider
└─ Complete free tier tutorial
2. Core Services (2-4 weeks)
├─ Compute (EC2/VMs)
├─ Storage (S3/Blob)
├─ Databases (RDS/SQL)
└─ Networking (VPC)
3. Advanced Topics (4-8 weeks)
├─ Security and IAM
├─ Monitoring and logging
├─ CI/CD pipelines
└─ Infrastructure as Code
4. Specialization (Ongoing)
├─ Serverless
├─ Containers and Kubernetes
├─ ML/AI services
└─ Cost optimization
Recommended Certifications
AWS:
- AWS Certified Solutions Architect - Associate
- AWS Certified Developer - Associate
- AWS Certified SysOps Administrator
Azure:
- Azure Fundamentals (AZ-900)
- Azure Administrator (AZ-104)
- Azure Solutions Architect (AZ-305)
GCP:
- Google Cloud Digital Leader
- Associate Cloud Engineer
- Professional Cloud Architect
Resources
Free Tiers
- AWS: 12 months free tier + always free services
- Azure: $200 credit for 30 days + always free services
- GCP: $300 credit for 90 days + always free services
Documentation
- AWS: https://docs.aws.amazon.com
- Azure: https://docs.microsoft.com/azure
- GCP: https://cloud.google.com/docs
Community
- AWS: r/aws, AWS Forums
- Azure: r/azure, Microsoft Tech Community
- GCP: r/googlecloud, Google Cloud Community
Tools
- Terraform: Multi-cloud IaC
- Ansible: Configuration management
- Kubernetes: Container orchestration
- Prometheus/Grafana: Monitoring
- Cost Management: CloudHealth, CloudCheckr
Next Steps: Choose a cloud provider and explore provider-specific documentation:
- AWS Documentation
- Azure Documentation
- Google Cloud Documentation
- Cloud Setup Guide - Getting started with cloud environments
Setup
Setup GPU instances
Make sure the hard disk size is at least 30 GB
curl https://raw.githubusercontent.com/GoogleCloudPlatform/compute-gpu-installation/main/linux/install_gpu_driver.py --output install_gpu_driver.py
# If required, change the driver version in the script from DRIVER_VERSION = "525.125.06" to a newer one, e.g. 550.54.15
sed -i 's/525.125.06/550.54.15/' install_gpu_driver.py
#run the script
sudo apt install python3-venv python3-dev
sudo python3 install_gpu_driver.py
#verify the installation
nvidia-smi
#install pytorch
pip3 install torch torchvision torchaudio
#install cuda toolkit
sudo apt install nvidia-cuda-toolkit
nvcc --version
Swap file
sudo fallocate -l 32G /swapfile
sudo chmod 600 /swapfile
sudo mkswap /swapfile
sudo swapon /swapfile
# persist across reboots
echo '/swapfile none swap sw 0 0' | sudo tee -a /etc/fstab
Google Cloud
Image storage (per GB / month) $0.05
- Custom image storage is billed on the archive size (which is smaller than the disk size).
- Note: a 10 GB boot disk is not enough for the GPU driver installation.
Amazon Web Services (AWS)
Table of Contents
- Introduction
- AWS Global Infrastructure
- Getting Started
- Core Compute Services
- Storage Services
- Database Services
- Networking Services
- Serverless Services
- Container Services
- Security Services
- Monitoring and Management
- DevOps and CI/CD
- Machine Learning Services
- Architecture Examples
- Cost Optimization
- Best Practices
- CLI Reference
Introduction
Amazon Web Services (AWS) is the world's most comprehensive and broadly adopted cloud platform, offering over 200 fully featured services from data centers globally.
Key Advantages
- Market Leader: Largest market share (~32%)
- Mature Ecosystem: Launched in 2006
- Service Breadth: 200+ services
- Global Reach: 30+ regions, 90+ availability zones
- Innovation: Rapid release of new features
- Community: Largest developer community
AWS Account Structure
┌─────────────────────────────────────────────┐
│ AWS Organization (Root) │
├─────────────────────────────────────────────┤
│ │
│ ┌──────────────┐ ┌──────────────┐ │
│ │ Production │ │ Development │ │
│ │ OU │ │ OU │ │
│ ├──────────────┤ ├──────────────┤ │
│ │ Account 1 │ │ Account 3 │ │
│ │ Account 2 │ │ Account 4 │ │
│ └──────────────┘ └──────────────┘ │
│ │
│ ┌──────────────┐ ┌──────────────┐ │
│ │ Security │ │ Sandbox │ │
│ │ OU │ │ OU │ │
│ └──────────────┘ └──────────────┘ │
└─────────────────────────────────────────────┘
AWS Global Infrastructure
Hierarchy
Region
└─ Availability Zones (AZs)
   └─ Data Centers
Edge Locations (CloudFront CDN) — a separate network outside regions, closer to users
Key Concepts
Region: Geographic area with multiple AZs
- Examples: us-east-1 (Virginia), eu-west-1 (Ireland)
- Completely independent
- Data doesn't leave region unless explicitly configured
Availability Zone: Isolated data center(s) within a region
- 2-6 AZs per region
- Low-latency connections between AZs
- Physical separation for fault tolerance
Edge Location: CDN endpoint for CloudFront
- 400+ edge locations globally
- Caches content closer to users
Region Selection Criteria
Factor Consideration
──────────────────────────────────────────────
Latency Distance to users
Compliance Data residency laws
Services Not all services in all regions
Cost Pricing varies by region
Getting Started
AWS CLI Installation
# Install AWS CLI v2 (Linux/macOS)
curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
unzip awscliv2.zip
sudo ./aws/install
# Verify installation
aws --version
# Configure AWS CLI
aws configure
# Enter:
# - AWS Access Key ID
# - AWS Secret Access Key
# - Default region (e.g., us-east-1)
# - Default output format (json/yaml/text/table)
# Alternative: Use environment variables
export AWS_ACCESS_KEY_ID="your-access-key"
export AWS_SECRET_ACCESS_KEY="your-secret-key"
export AWS_DEFAULT_REGION="us-east-1"
# Or use AWS profiles
aws configure --profile production
aws s3 ls --profile production
AWS CLI Configuration Files
# View configuration
cat ~/.aws/config
# [default]
# region = us-east-1
# output = json
#
# [profile production]
# region = us-west-2
# output = yaml
cat ~/.aws/credentials
# [default]
# aws_access_key_id = AKIAIOSFODNN7EXAMPLE
# aws_secret_access_key = wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY
#
# [production]
# aws_access_key_id = AKIAI44QH8DHBEXAMPLE
# aws_secret_access_key = je7MtGbClwBF/2Zp9Utk/h3yCo8nvbEXAMPLEKEY
Basic AWS CLI Commands
# Get caller identity
aws sts get-caller-identity
# List all regions
aws ec2 describe-regions --output table
# List available services
aws help
# Get help for specific service
aws ec2 help
aws s3 help
Core Compute Services
Amazon EC2 (Elastic Compute Cloud)
Virtual servers in the cloud.
Instance Types
Category Type vCPU Memory Use Case
──────────────────────────────────────────────────────────────
General t3.micro 2 1 GB Development
Purpose t3.medium 2 4 GB Web servers
m5.large 2 8 GB Applications
Compute c5.large 2 4 GB Batch processing
Optimized c5.xlarge 4 8 GB High-performance
Memory r5.large 2 16 GB Databases
Optimized r5.xlarge 4 32 GB Caching
Storage i3.large 2 15.25 GB NoSQL databases
Optimized d2.xlarge 4 30.5 GB Data warehousing
GPU p3.2xlarge 8 61 GB ML training
Instances g4dn.xlarge 4 16 GB ML inference
EC2 Pricing Models
Model Discount Commitment Use Case
─────────────────────────────────────────────────────────────
On-Demand Baseline None Unpredictable
Reserved Instance Up to 75% 1-3 years Steady state
Spot Instance Up to 90% None Fault-tolerant
Savings Plan Up to 72% 1-3 years Flexible
EC2 CLI Examples
# List all instances
aws ec2 describe-instances
# List instances with specific state
aws ec2 describe-instances \
--filters "Name=instance-state-name,Values=running" \
--query 'Reservations[].Instances[].[InstanceId,InstanceType,State.Name,PublicIpAddress]' \
--output table
# Launch an instance
aws ec2 run-instances \
--image-id ami-0c55b159cbfafe1f0 \
--instance-type t3.micro \
--key-name my-key-pair \
--security-group-ids sg-0123456789abcdef0 \
--subnet-id subnet-0123456789abcdef0 \
--tag-specifications 'ResourceType=instance,Tags=[{Key=Name,Value=MyWebServer}]'
# Stop an instance
aws ec2 stop-instances --instance-ids i-1234567890abcdef0
# Start an instance
aws ec2 start-instances --instance-ids i-1234567890abcdef0
# Terminate an instance
aws ec2 terminate-instances --instance-ids i-1234567890abcdef0
# Create AMI from instance
aws ec2 create-image \
--instance-id i-1234567890abcdef0 \
--name "MyWebServer-Backup-$(date +%Y%m%d)" \
--description "Backup of MyWebServer"
# List AMIs
aws ec2 describe-images --owners self
# Get instance metadata (from within the instance; IMDSv1)
curl http://169.254.169.254/latest/meta-data/
curl http://169.254.169.254/latest/meta-data/instance-id
curl http://169.254.169.254/latest/meta-data/public-ipv4
# IMDSv2 (token-based; required by default on newer instances)
TOKEN=$(curl -s -X PUT "http://169.254.169.254/latest/api/token" \
  -H "X-aws-ec2-metadata-token-ttl-seconds: 21600")
curl -H "X-aws-ec2-metadata-token: $TOKEN" http://169.254.169.254/latest/meta-data/instance-id
User Data Script Example
#!/bin/bash
# User data script for EC2 instance initialization
# Update system
yum update -y
# Install Apache web server
yum install -y httpd
# Start Apache
systemctl start httpd
systemctl enable httpd
# Create simple web page
echo "<h1>Hello from EC2!</h1>" > /var/www/html/index.html
# Install CloudWatch agent
wget https://s3.amazonaws.com/amazoncloudwatch-agent/amazon_linux/amd64/latest/amazon-cloudwatch-agent.rpm
rpm -U ./amazon-cloudwatch-agent.rpm
Auto Scaling
Automatically adjust capacity to maintain performance and costs.
Auto Scaling Architecture
┌─────────────────────────────────────────────────────┐
│ Application Load Balancer │
└──────────────────────┬──────────────────────────────┘
│
┌──────────────────┼──────────────────┐
│ │ │
┌───▼────┐ ┌───▼────┐ ┌───▼────┐
│ EC2 │ │ EC2 │ │ EC2 │
│ (Min) │ │ (Curr) │ │ (Max) │
└────────┘ └────────┘ └────────┘
│ │ │
└──────────────────┼──────────────────┘
│
┌──────────▼──────────┐
│ Auto Scaling Group │
│ │
│ Min: 2 │
│ Desired: 3 │
│ Max: 10 │
│ │
│ Scale Up: CPU>70% │
│ Scale Down: CPU<30%│
└─────────────────────┘
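The thresholds in the diagram can be sketched as a simple decision rule: scale out above 70% CPU, scale in below 30%, always clamped to the group's min/max. This is an illustration of the idea only, not the actual Auto Scaling algorithm; the function name and step size of one instance are assumptions for the sketch.

```python
def next_desired_capacity(current: int, avg_cpu: float,
                          min_size: int = 2, max_size: int = 10,
                          scale_out_at: float = 70.0,
                          scale_in_at: float = 30.0) -> int:
    """Return the new desired capacity for one evaluation period."""
    if avg_cpu > scale_out_at:
        desired = current + 1          # add an instance
    elif avg_cpu < scale_in_at:
        desired = current - 1          # remove an instance
    else:
        desired = current              # inside the target band: no change
    # Desired capacity can never leave the [min, max] range
    return max(min_size, min(max_size, desired))

print(next_desired_capacity(3, 85.0))  # 4
print(next_desired_capacity(2, 10.0))  # 2 (clamped at min)
```

Target tracking in the real service works on CloudWatch metrics over time rather than a single sample, but the clamp to [Min, Max] behaves as shown.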
Auto Scaling CLI Examples
# Create launch template
aws ec2 create-launch-template \
--launch-template-name my-template \
--version-description "Initial version" \
--launch-template-data '{
"ImageId": "ami-0c55b159cbfafe1f0",
"InstanceType": "t3.micro",
"KeyName": "my-key-pair",
"SecurityGroupIds": ["sg-0123456789abcdef0"]
}'
# Create Auto Scaling group
aws autoscaling create-auto-scaling-group \
--auto-scaling-group-name my-asg \
--launch-template "LaunchTemplateName=my-template,Version=1" \
--min-size 2 \
--max-size 10 \
--desired-capacity 3 \
--vpc-zone-identifier "subnet-12345,subnet-67890" \
--target-group-arns arn:aws:elasticloadbalancing:region:account-id:targetgroup/my-targets/73e2d6bc24d8a067 \
--health-check-type ELB \
--health-check-grace-period 300
# Create scaling policy (target tracking)
aws autoscaling put-scaling-policy \
--auto-scaling-group-name my-asg \
--policy-name cpu-target-tracking \
--policy-type TargetTrackingScaling \
--target-tracking-configuration '{
"PredefinedMetricSpecification": {
"PredefinedMetricType": "ASGAverageCPUUtilization"
},
"TargetValue": 70.0
}'
# Describe Auto Scaling groups
aws autoscaling describe-auto-scaling-groups \
--auto-scaling-group-names my-asg
# Update Auto Scaling group capacity
aws autoscaling set-desired-capacity \
--auto-scaling-group-name my-asg \
--desired-capacity 5
# Delete Auto Scaling group
aws autoscaling delete-auto-scaling-group \
--auto-scaling-group-name my-asg \
--force-delete
AWS Lambda (Serverless)
Run code without provisioning servers. Covered in detail in Serverless Services.
Storage Services
Amazon S3 (Simple Storage Service)
Object storage service with 99.999999999% (11 9's) durability.
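A quick back-of-the-envelope sketch of what 11 nines means in practice: with annual durability of 99.999999999%, the expected number of objects lost per year stays tiny even at large scale. The function below is illustrative arithmetic, not an AWS guarantee model.

```python
DURABILITY = 0.99999999999   # 11 nines, per object per year

def expected_annual_loss(num_objects: int) -> float:
    """Expected number of objects lost in a year at the stated durability."""
    return num_objects * (1 - DURABILITY)

# For 10 million objects, roughly 0.0001 objects per year on average
print(expected_annual_loss(10_000_000))
```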
S3 Storage Classes
Class Use Case Retrieval Cost
────────────────────────────────────────────────────────────────────────
Standard Frequently accessed Instant $$$
Intelligent-Tiering Unknown/changing patterns Instant $$+
Standard-IA Infrequently accessed Instant $$
One Zone-IA Non-critical, infrequent Instant $
Glacier Instant Archive, instant retrieval Instant $
Glacier Flexible Archive, min-hour retrieval Minutes-Hours ¢¢
Glacier Deep Archive Long-term archive (7-10yr) 12 hours ¢
S3 Architecture
┌─────────────────────────────────────────────┐
│ Bucket: my-application-bucket │
│ Region: us-east-1 │
├─────────────────────────────────────────────┤
│ │
│ /images/ │
│ ├─ logo.png │
│ └─ banner.jpg │
│ │
│ /documents/ │
│ ├─ report.pdf │
│ └─ invoice.xlsx │
│ │
│ /backups/ │
│ └─ database-backup-2024-01-01.sql │
│ │
│ Features: │
│ ├─ Versioning: Enabled │
│ ├─ Encryption: AES-256 │
│ ├─ Lifecycle: Move to Glacier after 90d │
│ ├─ Replication: Cross-region enabled │
│ └─ Access Logs: Enabled │
└─────────────────────────────────────────────┘
S3 CLI Examples
# Create bucket
aws s3 mb s3://my-unique-bucket-name-12345
# List buckets
aws s3 ls
# Upload file
aws s3 cp local-file.txt s3://my-bucket/
aws s3 cp local-file.txt s3://my-bucket/folder/
# Upload directory recursively
aws s3 cp ./my-directory s3://my-bucket/path/ --recursive
# Download file
aws s3 cp s3://my-bucket/file.txt ./
# Sync local directory with S3 (like rsync)
aws s3 sync ./local-dir s3://my-bucket/remote-dir/
aws s3 sync s3://my-bucket/remote-dir/ ./local-dir
# List objects in bucket
aws s3 ls s3://my-bucket/
aws s3 ls s3://my-bucket/folder/ --recursive
# Delete object
aws s3 rm s3://my-bucket/file.txt
# Delete all objects in folder
aws s3 rm s3://my-bucket/folder/ --recursive
# Make object public
aws s3api put-object-acl \
--bucket my-bucket \
--key file.txt \
--acl public-read
# Generate presigned URL (temporary access)
aws s3 presign s3://my-bucket/private-file.pdf --expires-in 3600
# Enable versioning
aws s3api put-bucket-versioning \
--bucket my-bucket \
--versioning-configuration Status=Enabled
# Enable server-side encryption
aws s3api put-bucket-encryption \
--bucket my-bucket \
--server-side-encryption-configuration '{
"Rules": [{
"ApplyServerSideEncryptionByDefault": {
"SSEAlgorithm": "AES256"
}
}]
}'
# Set lifecycle policy
aws s3api put-bucket-lifecycle-configuration \
--bucket my-bucket \
--lifecycle-configuration file://lifecycle.json
S3 Lifecycle Policy Example
{
"Rules": [
{
"Id": "MoveOldFilesToGlacier",
"Status": "Enabled",
"Filter": {
"Prefix": "logs/"
},
"Transitions": [
{
"Days": 30,
"StorageClass": "STANDARD_IA"
},
{
"Days": 90,
"StorageClass": "GLACIER"
}
],
"Expiration": {
"Days": 365
}
},
{
"Id": "DeleteOldVersions",
"Status": "Enabled",
"NoncurrentVersionExpiration": {
"NoncurrentDays": 30
}
}
]
}
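The MoveOldFilesToGlacier rule above can be traced for a single object as a timeline: Standard until day 30, Standard-IA until day 90, Glacier until day 365, then expired. This is a sketch of the schedule only, assuming the object matches the "logs/" prefix; it makes no S3 API calls.

```python
def storage_class_at(age_days: int) -> str:
    """Storage class (or EXPIRED) for a logs/ object of the given age."""
    if age_days >= 365:
        return "EXPIRED"          # Expiration: Days 365
    if age_days >= 90:
        return "GLACIER"          # second transition at 90 days
    if age_days >= 30:
        return "STANDARD_IA"      # first transition at 30 days
    return "STANDARD"             # newly written objects

for age in (5, 45, 180, 400):
    print(age, storage_class_at(age))
```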
S3 Bucket Policy Example
{
"Version": "2012-10-17",
"Statement": [
{
"Sid": "PublicReadGetObject",
"Effect": "Allow",
"Principal": "*",
"Action": "s3:GetObject",
"Resource": "arn:aws:s3:::my-bucket/public/*"
},
{
"Sid": "DenyInsecureTransport",
"Effect": "Deny",
"Principal": "*",
"Action": "s3:*",
"Resource": [
"arn:aws:s3:::my-bucket",
"arn:aws:s3:::my-bucket/*"
],
"Condition": {
"Bool": {
"aws:SecureTransport": "false"
}
}
}
]
}
S3 SDK Example (Python/Boto3)
import boto3
from botocore.exceptions import ClientError
# Create S3 client
s3 = boto3.client('s3')
# Upload file
def upload_file(file_name, bucket, object_name=None):
if object_name is None:
object_name = file_name
try:
s3.upload_file(file_name, bucket, object_name)
print(f"Uploaded {file_name} to {bucket}/{object_name}")
except ClientError as e:
print(f"Error: {e}")
return False
return True
# Download file
def download_file(bucket, object_name, file_name):
try:
s3.download_file(bucket, object_name, file_name)
print(f"Downloaded {bucket}/{object_name} to {file_name}")
except ClientError as e:
print(f"Error: {e}")
return False
return True
# List objects
def list_objects(bucket, prefix=''):
try:
response = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)
if 'Contents' in response:
for obj in response['Contents']:
print(f"{obj['Key']}: {obj['Size']} bytes")
except ClientError as e:
print(f"Error: {e}")
# Generate presigned URL
def create_presigned_url(bucket, object_name, expiration=3600):
try:
url = s3.generate_presigned_url(
'get_object',
Params={'Bucket': bucket, 'Key': object_name},
ExpiresIn=expiration
)
return url
except ClientError as e:
print(f"Error: {e}")
return None
# Usage
upload_file('local-file.txt', 'my-bucket', 'uploads/file.txt')
download_file('my-bucket', 'uploads/file.txt', 'downloaded-file.txt')
list_objects('my-bucket', 'uploads/')
url = create_presigned_url('my-bucket', 'uploads/file.txt')
print(f"Presigned URL: {url}")
Amazon EBS (Elastic Block Store)
Block storage for EC2 instances.
EBS Volume Types
Type IOPS Throughput Use Case Cost
────────────────────────────────────────────────────────────────────
gp3 3,000-16,000 125-1000 MB/s General purpose $$
gp2 100-16,000 Up to 250 MB/s General purpose $$
io2 64,000+ 1,000 MB/s Mission-critical DB $$$$
io1 32,000+ 500 MB/s High-performance DB $$$
st1 500 500 MB/s Big data, logs $
sc1 250 250 MB/s Cold data ¢
EBS CLI Examples
# Create EBS volume
aws ec2 create-volume \
--volume-type gp3 \
--size 100 \
--availability-zone us-east-1a \
--tag-specifications 'ResourceType=volume,Tags=[{Key=Name,Value=MyVolume}]'
# List volumes
aws ec2 describe-volumes
# Attach volume to instance
aws ec2 attach-volume \
--volume-id vol-0123456789abcdef0 \
--instance-id i-1234567890abcdef0 \
--device /dev/sdf
# Detach volume
aws ec2 detach-volume --volume-id vol-0123456789abcdef0
# Create snapshot
aws ec2 create-snapshot \
--volume-id vol-0123456789abcdef0 \
--description "Backup of MyVolume"
# List snapshots
aws ec2 describe-snapshots --owner-ids self
# Create volume from snapshot
aws ec2 create-volume \
--snapshot-id snap-0123456789abcdef0 \
--availability-zone us-east-1a
# Delete snapshot
aws ec2 delete-snapshot --snapshot-id snap-0123456789abcdef0
# Delete volume
aws ec2 delete-volume --volume-id vol-0123456789abcdef0
Amazon EFS (Elastic File System)
Managed NFS file system for EC2.
# Create EFS file system
aws efs create-file-system \
--performance-mode generalPurpose \
--throughput-mode bursting \
--encrypted \
--tags Key=Name,Value=MyEFS
# Create mount target
aws efs create-mount-target \
--file-system-id fs-0123456789abcdef0 \
--subnet-id subnet-0123456789abcdef0 \
--security-groups sg-0123456789abcdef0
# Mount EFS on EC2 instance
sudo mkdir /mnt/efs
sudo mount -t nfs4 -o nfsvers=4.1,rsize=1048576,wsize=1048576,hard,timeo=600,retrans=2 \
fs-0123456789abcdef0.efs.us-east-1.amazonaws.com:/ /mnt/efs
# Add to /etc/fstab for persistent mount
echo "fs-0123456789abcdef0.efs.us-east-1.amazonaws.com:/ /mnt/efs nfs4 defaults,_netdev 0 0" | sudo tee -a /etc/fstab
Database Services
Amazon RDS (Relational Database Service)
Managed relational databases.
Supported Engines
Engine Versions Use Case
───────────────────────────────────────────────────────
MySQL 5.7, 8.0 Web applications
PostgreSQL 11-15 Advanced features
MariaDB 10.3-10.6 MySQL alternative
Oracle 12c, 19c Enterprise apps
SQL Server 2016-2022 Microsoft stack
Amazon Aurora MySQL/PG compat High performance
RDS Architecture (Multi-AZ)
┌─────────────────────────────────────────────────┐
│ Application Servers │
└───────────────────┬─────────────────────────────┘
│
┌──────────▼──────────┐
│ RDS Endpoint │
│ (DNS CNAME) │
└──────────┬──────────┘
│
┌───────────────┼───────────────┐
│ │ │
┌───▼────┐ Sync Repl ┌───────▼─────┐
│Primary │◄──────────────►│ Standby │
│Instance│ │ Instance │
│(AZ-A) │ │ (AZ-B) │
└────────┘ └─────────────┘
│ │
│ Automatic Failover │
└────────────────────────────┘
RDS CLI Examples
# Create RDS instance
aws rds create-db-instance \
--db-instance-identifier mydb \
--db-instance-class db.t3.micro \
--engine postgres \
--engine-version 14.7 \
--master-username admin \
--master-user-password MySecurePassword123 \
--allocated-storage 20 \
--storage-type gp3 \
--vpc-security-group-ids sg-0123456789abcdef0 \
--db-subnet-group-name my-db-subnet-group \
--backup-retention-period 7 \
--preferred-backup-window "03:00-04:00" \
--preferred-maintenance-window "sun:04:00-sun:05:00" \
--multi-az \
--storage-encrypted \
--enable-cloudwatch-logs-exports '["postgresql"]'
# List RDS instances
aws rds describe-db-instances
# Get specific instance details
aws rds describe-db-instances \
--db-instance-identifier mydb \
--query 'DBInstances[0].[DBInstanceIdentifier,DBInstanceStatus,Endpoint.Address,Endpoint.Port]'
# Create read replica
aws rds create-db-instance-read-replica \
--db-instance-identifier mydb-replica \
--source-db-instance-identifier mydb \
--db-instance-class db.t3.micro \
--availability-zone us-east-1b
# Create snapshot
aws rds create-db-snapshot \
--db-instance-identifier mydb \
--db-snapshot-identifier mydb-snapshot-$(date +%Y%m%d)
# Restore from snapshot
aws rds restore-db-instance-from-db-snapshot \
--db-instance-identifier mydb-restored \
--db-snapshot-identifier mydb-snapshot-20240101
# Modify instance
aws rds modify-db-instance \
--db-instance-identifier mydb \
--db-instance-class db.t3.small \
--apply-immediately
# Stop instance (up to 7 days)
aws rds stop-db-instance --db-instance-identifier mydb
# Start instance
aws rds start-db-instance --db-instance-identifier mydb
# Delete instance
aws rds delete-db-instance \
--db-instance-identifier mydb \
--skip-final-snapshot
# Or with final snapshot:
# --final-db-snapshot-identifier mydb-final-snapshot
# Connect to RDS
psql -h mydb.c9akciq32.us-east-1.rds.amazonaws.com -U admin -d postgres
mysql -h mydb.c9akciq32.us-east-1.rds.amazonaws.com -u admin -p
Amazon DynamoDB
Fully managed NoSQL database.
DynamoDB Concepts
Table: Users
┌──────────────┬─────────────┬───────────┬─────────┬──────────┐
│ UserId (PK) │ Email (SK) │ Name │ Age │ Status │
├──────────────┼─────────────┼───────────┼─────────┼──────────┤
│ user-001 │ a@ex.com │ Alice │ 30 │ active │
│ user-002 │ b@ex.com │ Bob │ 25 │ active │
│ user-003 │ c@ex.com │ Charlie │ 35 │ inactive │
└──────────────┴─────────────┴───────────┴─────────┴──────────┘
PK = Partition Key (required, determines data distribution)
SK = Sort Key (optional, enables range queries)
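Why the partition key "determines data distribution" can be made concrete with a sketch: DynamoDB hashes the partition key value to choose an internal partition. The real hash function is internal to the service; MD5 over a fixed partition count is used here purely as an illustration.

```python
import hashlib

def partition_for(key: str, num_partitions: int = 4) -> int:
    """Map a partition key to one of num_partitions buckets (illustrative)."""
    digest = hashlib.md5(key.encode("utf-8")).hexdigest()
    return int(digest, 16) % num_partitions

# The same key always lands on the same partition; different keys spread out
for user in ("user-001", "user-002", "user-003"):
    print(user, "-> partition", partition_for(user))
```

This is also why a hot partition key (many reads/writes on one value) throttles a table: all of that traffic lands on one bucket.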
DynamoDB Capacity Modes
Mode                         Billing       Use Case                Cost
───────────────────────────────────────────────────────────────────────
On-Demand                    Per request   Unpredictable traffic   $$$$
Provisioned                  Per hour      Predictable traffic     $$-$$$
Provisioned + Auto Scaling   Per hour      Variable patterns       $$-$$$
DynamoDB CLI Examples
# Create table
aws dynamodb create-table \
--table-name Users \
--attribute-definitions \
AttributeName=UserId,AttributeType=S \
AttributeName=Email,AttributeType=S \
--key-schema \
AttributeName=UserId,KeyType=HASH \
AttributeName=Email,KeyType=RANGE \
--billing-mode PAY_PER_REQUEST \
--tags Key=Environment,Value=Production
# List tables
aws dynamodb list-tables
# Describe table
aws dynamodb describe-table --table-name Users
# Put item
aws dynamodb put-item \
--table-name Users \
--item '{
"UserId": {"S": "user-001"},
"Email": {"S": "alice@example.com"},
"Name": {"S": "Alice"},
"Age": {"N": "30"},
"Status": {"S": "active"}
}'
# Get item
aws dynamodb get-item \
--table-name Users \
--key '{
"UserId": {"S": "user-001"},
"Email": {"S": "alice@example.com"}
}'
# Query items (by partition key)
aws dynamodb query \
--table-name Users \
--key-condition-expression "UserId = :userId" \
--expression-attribute-values '{
":userId": {"S": "user-001"}
}'
# Scan table (read all items - expensive!)
aws dynamodb scan --table-name Users
# Update item
aws dynamodb update-item \
--table-name Users \
--key '{
"UserId": {"S": "user-001"},
"Email": {"S": "alice@example.com"}
}' \
--update-expression "SET #status = :newStatus, Age = Age + :inc" \
--expression-attribute-names '{"#status": "Status"}' \
--expression-attribute-values '{
":newStatus": {"S": "inactive"},
":inc": {"N": "1"}
}'
# Delete item
aws dynamodb delete-item \
--table-name Users \
--key '{
"UserId": {"S": "user-001"},
"Email": {"S": "alice@example.com"}
}'
# Batch write
aws dynamodb batch-write-item --request-items file://batch-write.json
# Create global secondary index
# (this table uses PAY_PER_REQUEST billing, so no ProvisionedThroughput
# is specified for the index)
aws dynamodb update-table \
--table-name Users \
--attribute-definitions AttributeName=Status,AttributeType=S \
--global-secondary-index-updates '[{
"Create": {
"IndexName": "StatusIndex",
"KeySchema": [{"AttributeName": "Status", "KeyType": "HASH"}],
"Projection": {"ProjectionType": "ALL"}
}
}]'
# Enable Point-in-Time Recovery
aws dynamodb update-continuous-backups \
--table-name Users \
--point-in-time-recovery-specification PointInTimeRecoveryEnabled=true
DynamoDB SDK Example (Python/Boto3)
import boto3
from boto3.dynamodb.conditions import Key, Attr
from decimal import Decimal
# Create DynamoDB resource
dynamodb = boto3.resource('dynamodb')
table = dynamodb.Table('Users')
# Put item
def create_user(user_id, email, name, age):
response = table.put_item(
Item={
'UserId': user_id,
'Email': email,
'Name': name,
'Age': age,
'Status': 'active'
}
)
return response
# Get item
def get_user(user_id, email):
response = table.get_item(
Key={
'UserId': user_id,
'Email': email
}
)
return response.get('Item')
# Query by partition key
def get_user_emails(user_id):
response = table.query(
KeyConditionExpression=Key('UserId').eq(user_id)
)
return response['Items']
# Query with sort key condition
def get_user_by_email_prefix(user_id, email_prefix):
response = table.query(
KeyConditionExpression=Key('UserId').eq(user_id) &
Key('Email').begins_with(email_prefix)
)
return response['Items']
# Scan with filter
def get_active_users():
response = table.scan(
FilterExpression=Attr('Status').eq('active')
)
return response['Items']
# Update item
def update_user_status(user_id, email, new_status):
response = table.update_item(
Key={
'UserId': user_id,
'Email': email
},
UpdateExpression='SET #status = :status',
ExpressionAttributeNames={
'#status': 'Status'
},
ExpressionAttributeValues={
':status': new_status
},
ReturnValues='ALL_NEW'
)
return response['Attributes']
# Batch write
def batch_create_users(users):
with table.batch_writer() as batch:
for user in users:
batch.put_item(Item=user)
# Usage
create_user('user-001', 'alice@example.com', 'Alice', 30)
user = get_user('user-001', 'alice@example.com')
print(user)
emails = get_user_emails('user-001')
update_user_status('user-001', 'alice@example.com', 'inactive')
Amazon ElastiCache
Managed in-memory cache (Redis/Memcached).
# Create Redis cluster
aws elasticache create-cache-cluster \
--cache-cluster-id my-redis-cluster \
--cache-node-type cache.t3.micro \
--engine redis \
--engine-version 7.0 \
--num-cache-nodes 1 \
--cache-subnet-group-name my-cache-subnet-group \
--security-group-ids sg-0123456789abcdef0
# Create Redis replication group (primary plus replicas, automatic failover)
aws elasticache create-replication-group \
--replication-group-id my-redis-cluster \
--replication-group-description "My Redis cluster" \
--engine redis \
--cache-node-type cache.t3.micro \
--num-cache-clusters 3 \
--automatic-failover-enabled \
--multi-az-enabled
# Describe clusters
aws elasticache describe-cache-clusters \
--show-cache-node-info
# Get endpoint
aws elasticache describe-cache-clusters \
--cache-cluster-id my-redis-cluster \
--query 'CacheClusters[0].CacheNodes[0].Endpoint'
Networking Services
Amazon VPC (Virtual Private Cloud)
Isolated network for your AWS resources.
VPC Architecture
┌─────────────────────────────────────────────────────────────┐
│ VPC: 10.0.0.0/16 │
│ Region: us-east-1 │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌───────────────────────────┐ ┌──────────────────────┐ │
│ │ Public Subnet (AZ-A) │ │ Public Subnet (AZ-B) │ │
│ │ 10.0.1.0/24 │ │ 10.0.2.0/24 │ │
│ │ │ │ │ │
│ │ ┌─────┐ ┌──────┐ │ │ ┌─────┐ ┌──────┐ │ │
│ │ │ NAT │ │ ALB │ │ │ │ NAT │ │ ALB │ │ │
│ │ └─────┘ └──────┘ │ │ └─────┘ └──────┘ │ │
│ └───────────┬───────────────┘ └──────────┬───────────┘ │
│ │ │ │
│ │ Internet Gateway │ │
│ └──────────────┬───────────────┘ │
│ │ │
│ ┌───────────────────────────┐ ┌──────────────────────┐ │
│ │ Private Subnet (AZ-A) │ │ Private Subnet (AZ-B)│ │
│ │ 10.0.11.0/24 │ │ 10.0.12.0/24 │ │
│ │ │ │ │ │
│ │ ┌─────┐ ┌─────┐ │ │ ┌─────┐ ┌─────┐ │ │
│ │ │ EC2 │ │ EC2 │ │ │ │ EC2 │ │ EC2 │ │ │
│ │ └─────┘ └─────┘ │ │ └─────┘ └─────┘ │ │
│ └───────────────────────────┘ └──────────────────────┘ │
│ │
│ ┌───────────────────────────┐ ┌──────────────────────┐ │
│ │ Database Subnet (AZ-A) │ │ Database Subnet (AZ-B│ │
│ │ 10.0.21.0/24 │ │ 10.0.22.0/24 │ │
│ │ │ │ │ │
│ │ ┌─────────┐ │ │ ┌─────────┐ │ │
│ │ │ RDS │ │ │ │ RDS │ │ │
│ │ └─────────┘ │ │ └─────────┘ │ │
│ └───────────────────────────┘ └──────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
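The addressing plan in the diagram above can be sanity-checked with the standard library's ipaddress module: every subnet must fall inside the VPC CIDR, and no two subnets may overlap. The CIDRs below are the ones shown in the diagram.

```python
import ipaddress

vpc = ipaddress.ip_network("10.0.0.0/16")
subnets = [ipaddress.ip_network(c) for c in (
    "10.0.1.0/24", "10.0.2.0/24",    # public
    "10.0.11.0/24", "10.0.12.0/24",  # private
    "10.0.21.0/24", "10.0.22.0/24",  # database
)]

# Every subnet is carved out of the VPC's address space
assert all(s.subnet_of(vpc) for s in subnets)
# No two subnets share addresses
assert not any(a.overlaps(b)
               for i, a in enumerate(subnets)
               for b in subnets[i + 1:])
print("addressing plan is consistent")
```

Running a check like this before `aws ec2 create-subnet` catches overlapping-CIDR errors that the API would otherwise reject one call at a time.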
VPC CLI Examples
# Create VPC
aws ec2 create-vpc \
--cidr-block 10.0.0.0/16 \
--tag-specifications 'ResourceType=vpc,Tags=[{Key=Name,Value=MyVPC}]'
# Create Internet Gateway
aws ec2 create-internet-gateway \
--tag-specifications 'ResourceType=internet-gateway,Tags=[{Key=Name,Value=MyIGW}]'
# Attach Internet Gateway to VPC
aws ec2 attach-internet-gateway \
--internet-gateway-id igw-0123456789abcdef0 \
--vpc-id vpc-0123456789abcdef0
# Create public subnet
aws ec2 create-subnet \
--vpc-id vpc-0123456789abcdef0 \
--cidr-block 10.0.1.0/24 \
--availability-zone us-east-1a \
--tag-specifications 'ResourceType=subnet,Tags=[{Key=Name,Value=PublicSubnet-AZ-A}]'
# Create private subnet
aws ec2 create-subnet \
--vpc-id vpc-0123456789abcdef0 \
--cidr-block 10.0.11.0/24 \
--availability-zone us-east-1a \
--tag-specifications 'ResourceType=subnet,Tags=[{Key=Name,Value=PrivateSubnet-AZ-A}]'
# Create route table
aws ec2 create-route-table \
--vpc-id vpc-0123456789abcdef0 \
--tag-specifications 'ResourceType=route-table,Tags=[{Key=Name,Value=PublicRouteTable}]'
# Create route to Internet Gateway
aws ec2 create-route \
--route-table-id rtb-0123456789abcdef0 \
--destination-cidr-block 0.0.0.0/0 \
--gateway-id igw-0123456789abcdef0
# Associate route table with subnet
aws ec2 associate-route-table \
--route-table-id rtb-0123456789abcdef0 \
--subnet-id subnet-0123456789abcdef0
# Create NAT Gateway (for private subnet internet access)
# First, allocate Elastic IP
aws ec2 allocate-address --domain vpc
# Create NAT Gateway in public subnet
aws ec2 create-nat-gateway \
--subnet-id subnet-0123456789abcdef0 \
--allocation-id eipalloc-0123456789abcdef0 \
--tag-specifications 'ResourceType=natgateway,Tags=[{Key=Name,Value=MyNATGateway}]'
# Create route to NAT Gateway for private subnet
aws ec2 create-route \
--route-table-id rtb-private-0123456789abcdef0 \
--destination-cidr-block 0.0.0.0/0 \
--nat-gateway-id nat-0123456789abcdef0
# Create security group
aws ec2 create-security-group \
--group-name web-server-sg \
--description "Security group for web servers" \
--vpc-id vpc-0123456789abcdef0
# Add inbound rules
aws ec2 authorize-security-group-ingress \
--group-id sg-0123456789abcdef0 \
--protocol tcp \
--port 80 \
--cidr 0.0.0.0/0
aws ec2 authorize-security-group-ingress \
--group-id sg-0123456789abcdef0 \
--protocol tcp \
--port 443 \
--cidr 0.0.0.0/0
aws ec2 authorize-security-group-ingress \
--group-id sg-0123456789abcdef0 \
--protocol tcp \
--port 22 \
--cidr 10.0.0.0/16
# List VPCs
aws ec2 describe-vpcs
# List subnets
aws ec2 describe-subnets --filters "Name=vpc-id,Values=vpc-0123456789abcdef0"
# List security groups
aws ec2 describe-security-groups --filters "Name=vpc-id,Values=vpc-0123456789abcdef0"
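The three ingress rules added above behave as an allow-list: a connection is permitted only if some rule matches its protocol, port, and source address. The sketch below is a simplified model of that evaluation (it ignores statefulness and egress), with the rule list mirroring the CLI commands.

```python
import ipaddress

RULES = [  # (protocol, port, source CIDR) — mirrors the ingress rules above
    ("tcp", 80, "0.0.0.0/0"),
    ("tcp", 443, "0.0.0.0/0"),
    ("tcp", 22, "10.0.0.0/16"),
]

def allowed(protocol: str, port: int, source_ip: str) -> bool:
    """True if any rule matches; security groups have no deny rules."""
    src = ipaddress.ip_address(source_ip)
    return any(protocol == p and port == prt
               and src in ipaddress.ip_network(cidr)
               for p, prt, cidr in RULES)

print(allowed("tcp", 443, "203.0.113.9"))  # True  (HTTPS from anywhere)
print(allowed("tcp", 22, "203.0.113.9"))   # False (SSH only from the VPC)
```

Note the contrast with network ACLs, which are ordered rule lists that support explicit deny entries.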
Elastic Load Balancing (ELB)
Distribute traffic across multiple targets.
Load Balancer Types
Type Use Case OSI Layer Cost
──────────────────────────────────────────────────────────────────
Application (ALB) HTTP/HTTPS, path routing Layer 7 $$
Network (NLB) TCP/UDP, ultra performance Layer 4 $$
Gateway (GWLB) Third-party appliances Layer 3 $$$
Classic (CLB) Legacy (deprecated) Layer 4/7 $
ALB Architecture
Internet
│
┌────────▼────────┐
│ Application │
│ Load Balancer │
│ (ALB) │
└────────┬────────┘
│
┌──────────────────┼──────────────────┐
│ │ │
┌─────▼─────┐ ┌─────▼─────┐ ┌─────▼─────┐
│ Target │ │ Target │ │ Target │
│ Group 1 │ │ Group 2 │ │ Group 3 │
│ │ │ │ │ │
│ /api/* │ │ /images/* │ │ /* │
└───────────┘ └───────────┘ └───────────┘
│ │ │
API Servers Image Service Web Servers
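The path-based routing in the diagram can be sketched as first-match-wins over a priority-ordered rule list, with `/*` as the catch-all. Here `fnmatch` stands in for ALB's path-pattern matching (both treat `*` as matching any characters, including `/`); the target-group names are the illustrative ones from the diagram.

```python
from fnmatch import fnmatch

RULES = [                        # checked in priority order
    ("/api/*", "api-servers"),
    ("/images/*", "image-service"),
    ("/*", "web-servers"),       # default rule
]

def route(path: str) -> str:
    """Return the target group for a request path (first matching rule)."""
    for pattern, target in RULES:
        if fnmatch(path, pattern):
            return target
    return "no-match"

print(route("/api/users"))     # api-servers
print(route("/index.html"))    # web-servers
```

This is why rule priority matters in `aws elbv2 create-rule`: if the catch-all had a lower priority number than `/api/*`, API traffic would never reach the API target group.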
Load Balancer CLI Examples
# Create Application Load Balancer
aws elbv2 create-load-balancer \
--name my-alb \
--subnets subnet-0123456789abcdef0 subnet-0123456789abcdef1 \
--security-groups sg-0123456789abcdef0 \
--scheme internet-facing \
--type application \
--ip-address-type ipv4
# Create target group
aws elbv2 create-target-group \
--name my-targets \
--protocol HTTP \
--port 80 \
--vpc-id vpc-0123456789abcdef0 \
--health-check-path /health \
--health-check-interval-seconds 30 \
--health-check-timeout-seconds 5 \
--healthy-threshold-count 2 \
--unhealthy-threshold-count 2
# Register targets
aws elbv2 register-targets \
--target-group-arn arn:aws:elasticloadbalancing:region:account-id:targetgroup/my-targets/73e2d6bc24d8a067 \
--targets Id=i-1234567890abcdef0 Id=i-0987654321abcdef0
# Create listener
aws elbv2 create-listener \
--load-balancer-arn arn:aws:elasticloadbalancing:region:account-id:loadbalancer/app/my-alb/50dc6c495c0c9188 \
--protocol HTTP \
--port 80 \
--default-actions Type=forward,TargetGroupArn=arn:aws:elasticloadbalancing:region:account-id:targetgroup/my-targets/73e2d6bc24d8a067
# Create HTTPS listener with certificate
aws elbv2 create-listener \
--load-balancer-arn arn:aws:elasticloadbalancing:region:account-id:loadbalancer/app/my-alb/50dc6c495c0c9188 \
--protocol HTTPS \
--port 443 \
--certificates CertificateArn=arn:aws:acm:region:account-id:certificate/12345678-1234-1234-1234-123456789012 \
--default-actions Type=forward,TargetGroupArn=arn:aws:elasticloadbalancing:region:account-id:targetgroup/my-targets/73e2d6bc24d8a067
# Create path-based routing rule
aws elbv2 create-rule \
--listener-arn arn:aws:elasticloadbalancing:region:account-id:listener/app/my-alb/50dc6c495c0c9188/f2f7dc8efc522ab2 \
--priority 10 \
--conditions Field=path-pattern,Values='/api/*' \
--actions Type=forward,TargetGroupArn=arn:aws:elasticloadbalancing:region:account-id:targetgroup/api-targets/73e2d6bc24d8a067
# Describe load balancers
aws elbv2 describe-load-balancers
# Describe target health
aws elbv2 describe-target-health \
--target-group-arn arn:aws:elasticloadbalancing:region:account-id:targetgroup/my-targets/73e2d6bc24d8a067
Amazon Route 53
Scalable DNS and domain registration.
# List hosted zones
aws route53 list-hosted-zones
# Create hosted zone
aws route53 create-hosted-zone \
--name example.com \
--caller-reference $(date +%s)
# Create A record
aws route53 change-resource-record-sets \
--hosted-zone-id Z1234567890ABC \
--change-batch '{
"Changes": [{
"Action": "CREATE",
"ResourceRecordSet": {
"Name": "www.example.com",
"Type": "A",
"TTL": 300,
"ResourceRecords": [{"Value": "192.0.2.1"}]
}
}]
}'
# Create CNAME record
aws route53 change-resource-record-sets \
--hosted-zone-id Z1234567890ABC \
--change-batch '{
"Changes": [{
"Action": "CREATE",
"ResourceRecordSet": {
"Name": "blog.example.com",
"Type": "CNAME",
"TTL": 300,
"ResourceRecords": [{"Value": "www.example.com"}]
}
}]
}'
# Create alias record (to ALB)
aws route53 change-resource-record-sets \
--hosted-zone-id Z1234567890ABC \
--change-batch '{
"Changes": [{
"Action": "CREATE",
"ResourceRecordSet": {
"Name": "api.example.com",
"Type": "A",
"AliasTarget": {
"HostedZoneId": "Z35SXDOTRQ7X7K",
"DNSName": "my-alb-1234567890.us-east-1.elb.amazonaws.com",
"EvaluateTargetHealth": true
}
}
}]
}'
# Health check for failover
aws route53 create-health-check \
--health-check-config \
IPAddress=192.0.2.1,Port=80,Type=HTTP,ResourcePath=/health,RequestInterval=30,FailureThreshold=3
Serverless Services
AWS Lambda
Run code without managing servers.
Lambda Architecture
┌────────────────────────────────────────────────┐
│ Event Sources │
├────────────────────────────────────────────────┤
│ │
│ API Gateway │ S3 │ DynamoDB │ SQS │ EventBridge │
│ │
└──────────────┬──────────┬──────────┬───────────┘
│ │ │
┌────▼────┐┌────▼────┐┌───▼──────┐
│ Lambda ││ Lambda ││ Lambda │
│Function ││Function ││ Function │
│ 1 ││ 2 ││ 3 │
└────┬────┘└────┬────┘└────┬─────┘
│ │ │
┌────▼──────────▼──────────▼─────┐
│ Destinations │
│ │
│ DynamoDB │ S3 │ SNS │ SQS │
└──────────────────────────────────┘
Lambda Function Example (Python)
import json
import boto3
s3 = boto3.client('s3')
dynamodb = boto3.resource('dynamodb')
table = dynamodb.Table('MyTable')
def lambda_handler(event, context):
"""
Lambda function handler
Args:
event: Event data passed to the function
context: Runtime information
Returns:
Response object
"""
# Log the event
print(f"Event: {json.dumps(event)}")
# Example: Process S3 event
if 'Records' in event:
for record in event['Records']:
bucket = record['s3']['bucket']['name']
key = record['s3']['object']['key']
print(f"Processing {key} from {bucket}")
# Process the file
try:
response = s3.get_object(Bucket=bucket, Key=key)
content = response['Body'].read().decode('utf-8')
# Store metadata in DynamoDB
table.put_item(
Item={
'file_key': key,
'bucket': bucket,
'size': response['ContentLength'],
'content_type': response['ContentType']
}
)
return {
'statusCode': 200,
'body': json.dumps('Successfully processed file')
}
except Exception as e:
print(f"Error: {str(e)}")
return {
'statusCode': 500,
'body': json.dumps(f'Error processing file: {str(e)}')
}
# Example: Process API Gateway event
if 'httpMethod' in event:
http_method = event['httpMethod']
path = event['path']
if http_method == 'GET' and path == '/items':
# Retrieve items from DynamoDB
response = table.scan()
return {
'statusCode': 200,
'headers': {
'Content-Type': 'application/json',
'Access-Control-Allow-Origin': '*'
},
'body': json.dumps(response['Items'])
}
elif http_method == 'POST' and path == '/items':
# Create new item
body = json.loads(event['body'])
table.put_item(Item=body)
return {
'statusCode': 201,
'headers': {
'Content-Type': 'application/json',
'Access-Control-Allow-Origin': '*'
},
'body': json.dumps({'message': 'Item created'})
}
return {
'statusCode': 400,
'body': json.dumps('Invalid request')
}
Lambda CLI Examples
# Create Lambda function
zip function.zip lambda_function.py
aws lambda create-function \
--function-name my-function \
--runtime python3.11 \
--role arn:aws:iam::123456789012:role/lambda-execution-role \
--handler lambda_function.lambda_handler \
--zip-file fileb://function.zip \
--timeout 30 \
--memory-size 256 \
--environment "Variables={ENV=production,DB_TABLE=MyTable}"
# Update function code
aws lambda update-function-code \
--function-name my-function \
--zip-file fileb://function.zip
# Update function configuration
aws lambda update-function-configuration \
--function-name my-function \
--timeout 60 \
--memory-size 512
# Invoke function synchronously
# (AWS CLI v2 expects a base64 payload by default; pass
# --cli-binary-format raw-in-base64-out to send raw JSON)
aws lambda invoke \
--function-name my-function \
--cli-binary-format raw-in-base64-out \
--payload '{"key": "value"}' \
response.json
cat response.json
# Invoke function asynchronously
aws lambda invoke \
--function-name my-function \
--invocation-type Event \
--cli-binary-format raw-in-base64-out \
--payload '{"key": "value"}' \
response.json
# List functions
aws lambda list-functions
# Get function details
aws lambda get-function --function-name my-function
# Add S3 trigger
aws lambda add-permission \
--function-name my-function \
--statement-id s3-invoke \
--action lambda:InvokeFunction \
--principal s3.amazonaws.com \
--source-arn arn:aws:s3:::my-bucket
aws s3api put-bucket-notification-configuration \
--bucket my-bucket \
--notification-configuration '{
"LambdaFunctionConfigurations": [{
"LambdaFunctionArn": "arn:aws:lambda:region:account-id:function:my-function",
"Events": ["s3:ObjectCreated:*"],
"Filter": {
"Key": {
"FilterRules": [{
"Name": "prefix",
"Value": "uploads/"
}]
}
}
}]
}'
# View logs
aws logs tail /aws/lambda/my-function --follow
# Create layer
zip -r layer.zip python/
aws lambda publish-layer-version \
--layer-name my-layer \
--description "Common dependencies" \
--zip-file fileb://layer.zip \
--compatible-runtimes python3.11
# Add layer to function
aws lambda update-function-configuration \
--function-name my-function \
--layers arn:aws:lambda:region:account-id:layer:my-layer:1
# Delete function
aws lambda delete-function --function-name my-function
Lambda Pricing
Component Price (us-east-1)
─────────────────────────────────────────────────
Requests $0.20 per 1M requests
Duration (x86) $0.0000166667 per GB-second
Duration (ARM/Graviton) $0.0000133334 per GB-second
Free Tier 1M requests + 400,000 GB-seconds/month
Example: 1 million requests, 512 MB, 1 second each
= 1M * $0.20/1M = $0.20 (requests)
+ 1M * 0.5 GB * 1 sec * $0.0000166667 = $8.33 (duration)
= $8.53/month (minus free tier)
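The worked example above can be turned into a small reusable calculator. The prices are the us-east-1 x86 figures from the table and may change; the free tier is deliberately ignored here.

```python
REQUEST_PRICE = 0.20 / 1_000_000        # per request
GB_SECOND_PRICE_X86 = 0.0000166667      # per GB-second (x86)

def monthly_cost(requests: int, memory_mb: int, avg_duration_s: float) -> float:
    """Lambda monthly cost before free-tier credits (us-east-1, x86)."""
    gb_seconds = requests * (memory_mb / 1024) * avg_duration_s
    return requests * REQUEST_PRICE + gb_seconds * GB_SECOND_PRICE_X86

cost = monthly_cost(1_000_000, 512, 1.0)
print(f"${cost:.2f}/month")  # $8.53/month, matching the worked example
```

Swapping in the ARM/Graviton GB-second price from the table shows why Graviton is the cheaper choice for compute-bound functions.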
API Gateway
Create, publish, and manage APIs.
# Create REST API
aws apigateway create-rest-api \
--name "My API" \
--description "My REST API" \
--endpoint-configuration types=REGIONAL
# Get root resource
aws apigateway get-resources \
--rest-api-id abc123
# Create resource
aws apigateway create-resource \
--rest-api-id abc123 \
--parent-id xyz789 \
--path-part items
# Create method
aws apigateway put-method \
--rest-api-id abc123 \
--resource-id uvw456 \
--http-method GET \
--authorization-type NONE
# Create Lambda integration
aws apigateway put-integration \
--rest-api-id abc123 \
--resource-id uvw456 \
--http-method GET \
--type AWS_PROXY \
--integration-http-method POST \
--uri arn:aws:apigateway:region:lambda:path/2015-03-31/functions/arn:aws:lambda:region:account-id:function:my-function/invocations
# Deploy API
aws apigateway create-deployment \
--rest-api-id abc123 \
--stage-name prod
# API URL format:
# https://abc123.execute-api.region.amazonaws.com/prod/items
# Enable API key
aws apigateway create-api-key \
--name "My API Key" \
--enabled
# Create usage plan
aws apigateway create-usage-plan \
--name "Basic Plan" \
--throttle burstLimit=100,rateLimit=50 \
--quota limit=10000,period=MONTH
# Associate API key with usage plan
aws apigateway create-usage-plan-key \
--usage-plan-id def456 \
--key-id ghi789 \
--key-type API_KEY
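The throttle settings in the usage plan follow token-bucket semantics: rateLimit is the steady refill rate (requests per second) and burstLimit is the bucket capacity. A toy model of that behavior (illustrative only, not API Gateway's actual implementation):

```python
# Toy token bucket mirroring burstLimit=100, rateLimit=50 from above.
class TokenBucket:
    def __init__(self, rate, burst):
        self.rate = rate        # tokens added per second (rateLimit)
        self.capacity = burst   # max tokens (burstLimit)
        self.tokens = burst
        self.last = 0.0

    def allow(self, now):
        # Refill based on elapsed time, capped at capacity.
        self.tokens = min(self.capacity,
                          self.tokens + (now - self.last) * self.rate)
        self.last = now
        if self.tokens >= 1:
            self.tokens -= 1
            return True
        return False

bucket = TokenBucket(rate=50, burst=100)
# 150 simultaneous requests: the first 100 burst through, 50 are throttled.
allowed = sum(bucket.allow(0.0) for _ in range(150))
print(allowed)  # -> 100
```

One second later the bucket has refilled 50 tokens, so traffic at the steady rateLimit keeps flowing.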
Container Services
Amazon ECS (Elastic Container Service)
Container orchestration service.
ECS Architecture
┌─────────────────────────────────────────────────┐
│ ECS Cluster │
├─────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────┐ │
│ │ ECS Service │ │
│ │ (Desired Count: 3) │ │
│ └──────────┬──────────────────────────────┘ │
│ │ │
│ ┌────────┼────────┐ │
│ │ │ │ │
│ ┌──▼──┐ ┌──▼──┐ ┌──▼──┐ │
│ │Task │ │Task │ │Task │ │
│ │ 1 │ │ 2 │ │ 3 │ │
│ └──┬──┘ └──┬──┘ └──┬──┘ │
│ │ │ │ │
│ ┌──▼───────▼────────▼───┐ │
│ │ Container(s) │ │
│ │ ┌────────────────┐ │ │
│ │ │ nginx:latest │ │ │
│ │ └────────────────┘ │ │
│ └────────────────────────┘ │
│ │
│ Launch Type: EC2 or Fargate │
└─────────────────────────────────────────────────┘
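The service scheduler's job in the diagram is to keep the running tasks converged to the desired count. A toy reconciliation loop showing that control-loop idea (hypothetical code, not the ECS scheduler):

```python
# Toy version of the ECS service control loop: start or stop tasks
# until the running set matches desired_count. Purely illustrative.
def reconcile(running_tasks, desired_count):
    tasks = list(running_tasks)
    next_id = max(tasks, default=0) + 1
    while len(tasks) < desired_count:     # scale out: launch new tasks
        tasks.append(next_id)
        next_id += 1
    while len(tasks) > desired_count:     # scale in: stop excess tasks
        tasks.pop()
    return tasks

print(reconcile([1, 2], 3))              # one task started -> [1, 2, 3]
print(reconcile([1, 2, 3, 4, 5], 3))     # two tasks stopped -> [1, 2, 3]
```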
ECS Task Definition Example
{
  "family": "web-app",
  "networkMode": "awsvpc",
  "requiresCompatibilities": ["FARGATE"],
  "cpu": "256",
  "memory": "512",
  "executionRoleArn": "arn:aws:iam::123456789012:role/ecsTaskExecutionRole",
  "containerDefinitions": [
    {
      "name": "nginx",
      "image": "nginx:latest",
      "portMappings": [
        {
          "containerPort": 80,
          "protocol": "tcp"
        }
      ],
      "essential": true,
      "environment": [
        {
          "name": "ENV",
          "value": "production"
        }
      ],
      "logConfiguration": {
        "logDriver": "awslogs",
        "options": {
          "awslogs-group": "/ecs/web-app",
          "awslogs-region": "us-east-1",
          "awslogs-stream-prefix": "nginx"
        }
      }
    }
  ]
}
ECS CLI Examples
# Create cluster (Fargate)
aws ecs create-cluster --cluster-name my-cluster
# Register task definition
aws ecs register-task-definition --cli-input-json file://task-definition.json
# Create service
aws ecs create-service \
--cluster my-cluster \
--service-name web-service \
--task-definition web-app:1 \
--desired-count 3 \
--launch-type FARGATE \
--network-configuration "awsvpcConfiguration={subnets=[subnet-12345,subnet-67890],securityGroups=[sg-12345],assignPublicIp=ENABLED}" \
--load-balancers "targetGroupArn=arn:aws:elasticloadbalancing:region:account-id:targetgroup/my-targets/73e2d6bc24d8a067,containerName=nginx,containerPort=80"
# List services
aws ecs list-services --cluster my-cluster
# Describe service
aws ecs describe-services \
--cluster my-cluster \
--services web-service
# Update service (e.g., change desired count)
aws ecs update-service \
--cluster my-cluster \
--service web-service \
--desired-count 5
# Run standalone task
aws ecs run-task \
--cluster my-cluster \
--task-definition web-app:1 \
--launch-type FARGATE \
--network-configuration "awsvpcConfiguration={subnets=[subnet-12345],securityGroups=[sg-12345],assignPublicIp=ENABLED}"
# View logs
aws logs tail /ecs/web-app --follow
# Stop task
aws ecs stop-task \
--cluster my-cluster \
--task arn:aws:ecs:region:account-id:task/my-cluster/abc123
# Delete service
aws ecs delete-service \
--cluster my-cluster \
--service web-service \
--force
# Delete cluster
aws ecs delete-cluster --cluster my-cluster
Amazon EKS (Elastic Kubernetes Service)
Managed Kubernetes service.
# Create EKS cluster (using eksctl - easier)
eksctl create cluster \
--name my-cluster \
--region us-east-1 \
--nodegroup-name standard-workers \
--node-type t3.medium \
--nodes 3 \
--nodes-min 1 \
--nodes-max 4 \
--managed
# Or using AWS CLI (more complex)
aws eks create-cluster \
--name my-cluster \
--role-arn arn:aws:iam::123456789012:role/eks-service-role \
--resources-vpc-config subnetIds=subnet-12345,subnet-67890,securityGroupIds=sg-12345
# Update kubeconfig
aws eks update-kubeconfig --name my-cluster --region us-east-1
# Verify connection
kubectl get nodes
# Deploy application
kubectl apply -f deployment.yaml
# List clusters
aws eks list-clusters
# Describe cluster
aws eks describe-cluster --name my-cluster
# Delete cluster (eksctl)
eksctl delete cluster --name my-cluster
AWS Fargate
Serverless compute for containers (works with ECS and EKS).
Benefits:
- No EC2 instances to manage
- Pay only for resources used
- Automatic scaling
- Built-in security
Use Cases:
- Microservices
- Batch processing
- CI/CD tasks
- Event-driven applications
Security Services
AWS IAM (Identity and Access Management)
Control access to AWS resources.
IAM Concepts
┌─────────────────────────────────────────┐
│ AWS Account │
├─────────────────────────────────────────┤
│ │
│ Users Groups Roles │
│ ├─ Alice ├─ Developers ├─ EC2 │
│ ├─ Bob ├─ Admins ├─ Lambda│
│ └─ Charlie └─ Viewers └─ ECS │
│ │
│ Policies (JSON documents) │
│ ├─ Managed Policies (AWS/Custom) │
│ └─ Inline Policies │
└─────────────────────────────────────────┘
IAM Policy Example
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Sid": "AllowS3ReadWrite",
      "Effect": "Allow",
      "Action": [
        "s3:GetObject",
        "s3:PutObject",
        "s3:DeleteObject"
      ],
      "Resource": "arn:aws:s3:::my-bucket/*"
    },
    {
      "Sid": "AllowS3ListBucket",
      "Effect": "Allow",
      "Action": "s3:ListBucket",
      "Resource": "arn:aws:s3:::my-bucket"
    },
    {
      "Sid": "DenyInsecureTransport",
      "Effect": "Deny",
      "Action": "s3:*",
      "Resource": [
        "arn:aws:s3:::my-bucket",
        "arn:aws:s3:::my-bucket/*"
      ],
      "Condition": {
        "Bool": {
          "aws:SecureTransport": "false"
        }
      }
    }
  ]
}
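The third statement works because IAM evaluates an explicit Deny before any Allow: if any matching statement denies, the request is denied regardless of allows. A toy evaluator showing just that ordering (real IAM also handles wildcards, conditions, and ARN matching):

```python
# Toy IAM-style evaluation: explicit Deny wins, then explicit Allow,
# otherwise implicit deny. Statements are simplified dicts with
# exact-match action/resource lists (no wildcards).
def evaluate(statements, action, resource):
    decision = "ImplicitDeny"
    for s in statements:
        if action in s["Action"] and resource in s["Resource"]:
            if s["Effect"] == "Deny":
                return "Deny"          # explicit deny overrides everything
            decision = "Allow"
    return decision

statements = [
    {"Effect": "Allow", "Action": ["s3:GetObject", "s3:PutObject"],
     "Resource": ["my-bucket"]},
    {"Effect": "Deny", "Action": ["s3:PutObject"],
     "Resource": ["my-bucket"]},
]
print(evaluate(statements, "s3:PutObject", "my-bucket"))    # -> Deny
print(evaluate(statements, "s3:GetObject", "my-bucket"))    # -> Allow
print(evaluate(statements, "s3:DeleteObject", "my-bucket")) # -> ImplicitDeny
```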
IAM CLI Examples
# Create user
aws iam create-user --user-name alice
# Create access key
aws iam create-access-key --user-name alice
# Create group
aws iam create-group --group-name developers
# Add user to group
aws iam add-user-to-group \
--user-name alice \
--group-name developers
# Create policy
aws iam create-policy \
--policy-name S3ReadWritePolicy \
--policy-document file://policy.json
# Attach policy to user
aws iam attach-user-policy \
--user-name alice \
--policy-arn arn:aws:iam::123456789012:policy/S3ReadWritePolicy
# Attach policy to group
aws iam attach-group-policy \
--group-name developers \
--policy-arn arn:aws:iam::aws:policy/PowerUserAccess
# Create role (for EC2)
aws iam create-role \
--role-name EC2-S3-Role \
--assume-role-policy-document '{
"Version": "2012-10-17",
"Statement": [{
"Effect": "Allow",
"Principal": {"Service": "ec2.amazonaws.com"},
"Action": "sts:AssumeRole"
}]
}'
# Attach policy to role
aws iam attach-role-policy \
--role-name EC2-S3-Role \
--policy-arn arn:aws:iam::aws:policy/AmazonS3ReadOnlyAccess
# Create instance profile
aws iam create-instance-profile \
--instance-profile-name EC2-S3-Profile
# Add role to instance profile
aws iam add-role-to-instance-profile \
--instance-profile-name EC2-S3-Profile \
--role-name EC2-S3-Role
# Associate instance profile with EC2
aws ec2 associate-iam-instance-profile \
--instance-id i-1234567890abcdef0 \
--iam-instance-profile Name=EC2-S3-Profile
# List users
aws iam list-users
# List policies attached to user
aws iam list-attached-user-policies --user-name alice
# Delete user (must remove from groups and detach policies first)
aws iam remove-user-from-group --user-name alice --group-name developers
aws iam detach-user-policy --user-name alice --policy-arn arn:aws:iam::123456789012:policy/S3ReadWritePolicy
aws iam delete-user --user-name alice
AWS Secrets Manager
Store and rotate secrets.
# Create secret
aws secretsmanager create-secret \
--name prod/db/password \
--description "Database password for production" \
--secret-string '{"username":"admin","password":"MySecurePassword123"}'
# Get secret value
aws secretsmanager get-secret-value --secret-id prod/db/password
# Update secret
aws secretsmanager update-secret \
--secret-id prod/db/password \
--secret-string '{"username":"admin","password":"NewPassword456"}'
# Enable automatic rotation
aws secretsmanager rotate-secret \
--secret-id prod/db/password \
--rotation-lambda-arn arn:aws:lambda:region:account-id:function:my-rotation-function \
--rotation-rules AutomaticallyAfterDays=30
# Delete secret (with recovery window)
aws secretsmanager delete-secret \
--secret-id prod/db/password \
--recovery-window-in-days 30
Use Secret in Lambda (Python)
import boto3
import json

def get_secret(secret_name):
    client = boto3.client('secretsmanager')
    try:
        response = client.get_secret_value(SecretId=secret_name)
        secret = json.loads(response['SecretString'])
        return secret
    except Exception as e:
        print(f"Error retrieving secret: {e}")
        raise

def lambda_handler(event, context):
    # Get database credentials
    db_secret = get_secret('prod/db/password')
    username = db_secret['username']
    password = db_secret['password']
    # Use credentials to connect to database
    # ...
    return {'statusCode': 200}
AWS KMS (Key Management Service)
Manage encryption keys.
# Create KMS key
aws kms create-key \
--description "Application data encryption key"
# Create alias
aws kms create-alias \
--alias-name alias/app-data-key \
--target-key-id 1234abcd-12ab-34cd-56ef-1234567890ab
# Encrypt data
aws kms encrypt \
--key-id alias/app-data-key \
--plaintext "sensitive data" \
--output text \
--query CiphertextBlob
# Decrypt data
aws kms decrypt \
--ciphertext-blob fileb://encrypted-data \
--output text \
--query Plaintext | base64 --decode
# List keys
aws kms list-keys
# Enable key rotation
aws kms enable-key-rotation --key-id 1234abcd-12ab-34cd-56ef-1234567890ab
Monitoring and Management
Amazon CloudWatch
Monitoring and observability service.
CloudWatch Metrics
# Put custom metric
aws cloudwatch put-metric-data \
--namespace "MyApp" \
--metric-name "RequestCount" \
--value 100 \
--timestamp $(date -u +"%Y-%m-%dT%H:%M:%S")
# Get metric statistics
aws cloudwatch get-metric-statistics \
--namespace AWS/EC2 \
--metric-name CPUUtilization \
--dimensions Name=InstanceId,Value=i-1234567890abcdef0 \
--start-time $(date -u -d '1 hour ago' +"%Y-%m-%dT%H:%M:%S") \
--end-time $(date -u +"%Y-%m-%dT%H:%M:%S") \
--period 300 \
--statistics Average
# Create alarm
aws cloudwatch put-metric-alarm \
--alarm-name high-cpu \
--alarm-description "Alert when CPU exceeds 80%" \
--metric-name CPUUtilization \
--namespace AWS/EC2 \
--statistic Average \
--period 300 \
--evaluation-periods 2 \
--threshold 80 \
--comparison-operator GreaterThanThreshold \
--dimensions Name=InstanceId,Value=i-1234567890abcdef0 \
--alarm-actions arn:aws:sns:region:account-id:my-topic
# List alarms
aws cloudwatch describe-alarms
# Delete alarm
aws cloudwatch delete-alarms --alarm-names high-cpu
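The alarm above fires only after two consecutive 5-minute periods breach the threshold (--evaluation-periods 2). A toy evaluator of that rule (real CloudWatch also handles missing data and datapoints-to-alarm settings):

```python
# Toy alarm evaluation: ALARM when the last `evaluation_periods`
# datapoints all breach the threshold (GreaterThanThreshold).
def alarm_state(datapoints, threshold, evaluation_periods):
    if len(datapoints) < evaluation_periods:
        return "INSUFFICIENT_DATA"
    recent = datapoints[-evaluation_periods:]
    return "ALARM" if all(v > threshold for v in recent) else "OK"

cpu = [55, 83, 90]   # average CPU per 5-minute period
print(alarm_state(cpu, threshold=80, evaluation_periods=2))  # -> ALARM
print(alarm_state([55, 83], 80, 2))   # only one breaching period -> OK
```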
CloudWatch Logs
# Create log group
aws logs create-log-group --log-group-name /aws/lambda/my-function
# Create log stream
aws logs create-log-stream \
--log-group-name /aws/lambda/my-function \
--log-stream-name 2024/01/01/instance-123
# Put log events
aws logs put-log-events \
--log-group-name /aws/lambda/my-function \
--log-stream-name 2024/01/01/instance-123 \
--log-events timestamp=$(date +%s000),message="Application started"
# Tail logs
aws logs tail /aws/lambda/my-function --follow
# Filter logs
aws logs filter-log-events \
--log-group-name /aws/lambda/my-function \
--filter-pattern "ERROR" \
--start-time $(date -d '1 hour ago' +%s)000
# Create metric filter
aws logs put-metric-filter \
--log-group-name /aws/lambda/my-function \
--filter-name ErrorCount \
--filter-pattern "ERROR" \
--metric-transformations \
metricName=ErrorCount,metricNamespace=MyApp,metricValue=1
# Export logs to S3
aws logs create-export-task \
--log-group-name /aws/lambda/my-function \
--from $(date -d '1 day ago' +%s)000 \
--to $(date +%s)000 \
--destination my-logs-bucket \
--destination-prefix lambda-logs/
# Set retention policy
aws logs put-retention-policy \
--log-group-name /aws/lambda/my-function \
--retention-in-days 30
# Delete log group
aws logs delete-log-group --log-group-name /aws/lambda/my-function
AWS CloudTrail
Track user activity and API usage.
# Create trail
aws cloudtrail create-trail \
--name my-trail \
--s3-bucket-name my-cloudtrail-bucket
# Start logging
aws cloudtrail start-logging --name my-trail
# Lookup events
aws cloudtrail lookup-events \
--lookup-attributes AttributeKey=EventName,AttributeValue=RunInstances \
--max-results 10
# Get trail status
aws cloudtrail get-trail-status --name my-trail
# Stop logging
aws cloudtrail stop-logging --name my-trail
# Delete trail
aws cloudtrail delete-trail --name my-trail
DevOps and CI/CD
AWS CodeCommit
Git repository hosting.
# Create repository
aws codecommit create-repository \
--repository-name my-repo \
--repository-description "My application code"
# Clone repository
git clone https://git-codecommit.us-east-1.amazonaws.com/v1/repos/my-repo
# Or with SSH
git clone ssh://git-codecommit.us-east-1.amazonaws.com/v1/repos/my-repo
# List repositories
aws codecommit list-repositories
# Get repository details
aws codecommit get-repository --repository-name my-repo
# Delete repository
aws codecommit delete-repository --repository-name my-repo
AWS CodeBuild
Build and test code.
buildspec.yml Example
version: 0.2

phases:
  install:
    runtime-versions:
      python: 3.11
    commands:
      - echo "Installing dependencies..."
      - pip install -r requirements.txt
  pre_build:
    commands:
      - echo "Running tests..."
      - pytest tests/
      - echo "Logging in to Amazon ECR..."
      - aws ecr get-login-password --region $AWS_DEFAULT_REGION | docker login --username AWS --password-stdin $AWS_ACCOUNT_ID.dkr.ecr.$AWS_DEFAULT_REGION.amazonaws.com
  build:
    commands:
      - echo "Building Docker image..."
      - docker build -t $IMAGE_REPO_NAME:$IMAGE_TAG .
      - docker tag $IMAGE_REPO_NAME:$IMAGE_TAG $AWS_ACCOUNT_ID.dkr.ecr.$AWS_DEFAULT_REGION.amazonaws.com/$IMAGE_REPO_NAME:$IMAGE_TAG
  post_build:
    commands:
      - echo "Pushing Docker image..."
      - docker push $AWS_ACCOUNT_ID.dkr.ecr.$AWS_DEFAULT_REGION.amazonaws.com/$IMAGE_REPO_NAME:$IMAGE_TAG
      - echo "Build completed on `date`"

artifacts:
  files:
    - '**/*'
  name: build-output

cache:
  paths:
    - '/root/.cache/pip/**/*'
# Create build project
aws codebuild create-project \
--name my-build-project \
--source type=CODECOMMIT,location=https://git-codecommit.us-east-1.amazonaws.com/v1/repos/my-repo \
--artifacts type=S3,location=my-build-artifacts-bucket \
--environment type=LINUX_CONTAINER,image=aws/codebuild/standard:5.0,computeType=BUILD_GENERAL1_SMALL \
--service-role arn:aws:iam::123456789012:role/codebuild-service-role
# Start build
aws codebuild start-build --project-name my-build-project
# Get build details
aws codebuild batch-get-builds --ids my-build-project:build-id
AWS CodeDeploy
Automate application deployments.
# Create application
aws deploy create-application \
--application-name my-app \
--compute-platform Server
# Create deployment group
aws deploy create-deployment-group \
--application-name my-app \
--deployment-group-name production \
--deployment-config-name CodeDeployDefault.OneAtATime \
--ec2-tag-filters Key=Environment,Value=Production,Type=KEY_AND_VALUE \
--service-role-arn arn:aws:iam::123456789012:role/CodeDeployServiceRole
# Create deployment
aws deploy create-deployment \
--application-name my-app \
--deployment-group-name production \
--s3-location bucket=my-deployments-bucket,key=app-v1.0.zip,bundleType=zip
# Get deployment status
aws deploy get-deployment --deployment-id d-ABCDEF123
AWS CodePipeline
Continuous delivery service.
Pipeline Structure
┌──────────────┐ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐
│ Source │───▶│ Build │───▶│ Test │───▶│ Deploy │
│ (CodeCommit) │ │ (CodeBuild) │ │ (CodeBuild) │ │ (CodeDeploy) │
└──────────────┘ └──────────────┘ └──────────────┘ └──────────────┘
# Create pipeline
aws codepipeline create-pipeline --cli-input-json file://pipeline.json
# Get pipeline details
aws codepipeline get-pipeline --name my-pipeline
# Start pipeline execution
aws codepipeline start-pipeline-execution --name my-pipeline
# Get pipeline state
aws codepipeline get-pipeline-state --name my-pipeline
Machine Learning Services
Amazon SageMaker
Build, train, and deploy ML models.
import boto3
import sagemaker
from sagemaker import get_execution_role
from sagemaker.sklearn import SKLearn

# Set up
role = get_execution_role()
session = sagemaker.Session()
bucket = session.default_bucket()

# Train model
sklearn_estimator = SKLearn(
    entry_point='train.py',
    role=role,
    instance_type='ml.m5.xlarge',
    framework_version='0.23-1',
    hyperparameters={
        'n_estimators': 100,
        'max_depth': 5
    }
)
sklearn_estimator.fit({'train': 's3://bucket/train-data'})

# Deploy model
predictor = sklearn_estimator.deploy(
    initial_instance_count=1,
    instance_type='ml.t2.medium'
)

# Make predictions
result = predictor.predict(data)
Amazon Rekognition
Image and video analysis.
import boto3

rekognition = boto3.client('rekognition')

# Detect labels in image
response = rekognition.detect_labels(
    Image={'S3Object': {'Bucket': 'my-bucket', 'Name': 'image.jpg'}},
    MaxLabels=10,
    MinConfidence=75
)
for label in response['Labels']:
    print(f"{label['Name']}: {label['Confidence']:.2f}%")

# Detect faces
response = rekognition.detect_faces(
    Image={'S3Object': {'Bucket': 'my-bucket', 'Name': 'face.jpg'}},
    Attributes=['ALL']
)

# Compare faces
response = rekognition.compare_faces(
    SourceImage={'S3Object': {'Bucket': 'my-bucket', 'Name': 'source.jpg'}},
    TargetImage={'S3Object': {'Bucket': 'my-bucket', 'Name': 'target.jpg'}},
    SimilarityThreshold=80
)
Amazon Comprehend
Natural language processing.
import boto3

comprehend = boto3.client('comprehend')

text = "Amazon Web Services is a great cloud platform."

# Detect sentiment
sentiment = comprehend.detect_sentiment(Text=text, LanguageCode='en')
print(f"Sentiment: {sentiment['Sentiment']}")

# Detect entities
entities = comprehend.detect_entities(Text=text, LanguageCode='en')
for entity in entities['Entities']:
    print(f"{entity['Text']}: {entity['Type']}")

# Detect key phrases
phrases = comprehend.detect_key_phrases(Text=text, LanguageCode='en')
for phrase in phrases['KeyPhrases']:
    print(phrase['Text'])
Architecture Examples
Three-Tier Web Application
Internet
│
┌────────▼────────┐
│ CloudFront │ CDN
│ (Optional) │
└────────┬────────┘
│
┌────────▼────────┐
│ Route 53 │ DNS
└────────┬────────┘
│
┌────────────────────────────▼────────────────────────────────┐
│ VPC │
│ │
│ ┌───────────────────────────────────────────────────────┐ │
│ │ Public Subnet (AZ-A) Public Subnet (AZ-B) │ │
│ │ ┌─────────────────┐ ┌─────────────────┐ │ │
│ │ │ Application │ │ Application │ │ │
│ │ │ Load Balancer │ │ Load Balancer │ │ │
│ │ └────────┬────────┘ └────────┬────────┘ │ │
│ └───────────┼──────────────────────┼───────────────────┘ │
│ │ │ │
│ ┌───────────▼──────────────────────▼───────────────────┐ │
│ │ Private Subnet (AZ-A) Private Subnet (AZ-B) │ │
│ │ ┌─────────────┐ ┌─────────────┐ │ │
│ │ │ Auto Scaling│ │ Auto Scaling│ │ │
│ │ │ Group │ │ Group │ │ │
│ │ │ ┌───┐ ┌───┐ │ ┌───┐ ┌───┐ │ │
│ │ │ │EC2│ │EC2│ │ │EC2│ │EC2│ │ │
│ │ │ └─┬─┘ └─┬─┘ │ └─┬─┘ └─┬─┘ │ │
│ │ └────┼─────┼────────────┘────┼─────┼──────────────┘ │
│ │ │ │ │ │ │
│ │ ┌────▼─────▼─────────────────▼─────▼──────────────┐ │
│ │ │ Database Subnet (AZ-A) Database Subnet (AZ-B)│ │
│ │ │ ┌──────────────┐ ┌──────────────┐ │ │
│ │ │ │ RDS Primary │◄────────▶│ RDS Standby │ │ │
│ │ │ └──────────────┘ └──────────────┘ │ │
│ │ │ │ │
│ │ │ ┌──────────────┐ │ │
│ │ │ │ ElastiCache │ │ │
│ │ │ └──────────────┘ │ │
│ │ └──────────────────────────────────────────────────┘ │
│ │ │
│ │ Additional Services: │
│ │ ├─ S3: Static assets │
│ │ ├─ CloudWatch: Monitoring │
│ │ ├─ CloudTrail: Audit logs │
│ │ └─ WAF: Web application firewall │
└──────────────────────────────────────────────────────────────┘
Serverless Microservices
┌─────────────┐
│ Users │
└──────┬──────┘
│
┌────────▼────────┐
│ CloudFront + │
│ S3 (Frontend) │
└────────┬────────┘
│
┌────────▼────────┐
│ API Gateway │
└────────┬────────┘
│
┌──────────────────────┼──────────────────────┐
│ │ │
┌────▼─────┐ ┌────▼─────┐ ┌────▼─────┐
│ Lambda │ │ Lambda │ │ Lambda │
│ User Svc │ │ Order Svc│ │ Pay Svc │
└────┬─────┘ └────┬─────┘ └────┬─────┘
│ │ │
┌────▼─────┐ ┌────▼─────┐ ┌────▼─────┐
│DynamoDB │ │DynamoDB │ │DynamoDB │
│Users │ │Orders │ │Payments │
└──────────┘ └──────────┘ └──────────┘
│ │ │
└─────────────────────┼─────────────────────┘
│
┌──────▼──────┐
│ EventBridge│
│ SNS │
└─────────────┘
Cost Optimization
Cost Optimization Strategies
┌──────────────────────────────────────────────────────────┐
│ AWS Cost Optimization Checklist │
├──────────────────────────────────────────────────────────┤
│ │
│ Compute │
│ ☐ Use Reserved Instances for steady workloads │
│ ☐ Use Spot Instances for fault-tolerant workloads │
│ ☐ Right-size instances based on metrics │
│ ☐ Use Savings Plans for flexible commitments │
│ ☐ Stop development/test instances off-hours │
│ ☐ Use Lambda/Fargate for serverless workloads │
│ ☐ Enable EC2 Auto Scaling │
│ │
│ Storage │
│ ☐ Use S3 Lifecycle policies │
│ ☐ Move infrequent data to S3-IA or Glacier │
│ ☐ Delete unattached EBS volumes │
│ ☐ Delete old snapshots │
│ ☐ Use S3 Intelligent-Tiering │
│ ☐ Enable EBS volume encryption only when needed │
│ │
│ Database │
│ ☐ Use Aurora Serverless for variable workloads │
│ ☐ Stop RDS instances when not in use │
│ ☐ Use DynamoDB On-Demand for unpredictable traffic │
│ ☐ Use read replicas efficiently │
│ ☐ Right-size RDS instances │
│ │
│ Network │
│ ☐ Use CloudFront to reduce data transfer costs │
│ ☐ Use VPC endpoints to avoid NAT Gateway costs │
│ ☐ Consolidate data transfer within same region │
│ ☐ Use Direct Connect for high volume transfers │
│ │
│ Monitoring │
│ ☐ Set up AWS Budgets with alerts │
│ ☐ Use Cost Explorer to analyze spending │
│ ☐ Enable Cost Allocation Tags │
│ ☐ Use Trusted Advisor cost optimization checks │
│ ☐ Review AWS Cost Anomaly Detection │
└──────────────────────────────────────────────────────────┘
AWS Cost Management CLI
# Set up budget
aws budgets create-budget \
--account-id 123456789012 \
--budget '{
"BudgetName": "Monthly-Budget",
"BudgetLimit": {
"Amount": "1000",
"Unit": "USD"
},
"TimeUnit": "MONTHLY",
"BudgetType": "COST"
}' \
--notifications-with-subscribers '[{
"Notification": {
"NotificationType": "ACTUAL",
"ComparisonOperator": "GREATER_THAN",
"Threshold": 80,
"ThresholdType": "PERCENTAGE"
},
"Subscribers": [{
"SubscriptionType": "EMAIL",
"Address": "email@example.com"
}]
}]'
# Get cost and usage
aws ce get-cost-and-usage \
--time-period Start=2024-01-01,End=2024-01-31 \
--granularity DAILY \
--metrics BlendedCost
# Get cost forecast
aws ce get-cost-forecast \
--time-period Start=2024-02-01,End=2024-02-29 \
--metric BLENDED_COST \
--granularity MONTHLY
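The budget above notifies at 80% of actual spend, and Cost Explorer can forecast the month's total. A toy check combining actuals with a naive linear forecast (hypothetical logic; AWS Budgets and Cost Explorer use more sophisticated forecasting):

```python
# Toy budget check: linear extrapolation of month-to-date spend, plus
# the 80%-of-actual alert threshold from the budget JSON above.
def budget_alerts(daily_costs, days_in_month, budget, threshold_pct=80):
    actual = sum(daily_costs)
    forecast = actual / len(daily_costs) * days_in_month
    alerts = []
    if actual >= budget * threshold_pct / 100:
        alerts.append("ACTUAL")
    if forecast > budget:
        alerts.append("FORECASTED")
    return actual, round(forecast, 2), alerts

# 10 days at $35/day against a $1000 monthly budget
print(budget_alerts([35.0] * 10, 31, 1000))  # -> (350.0, 1085.0, ['FORECASTED'])
```

Spend is on pace to exceed the budget even though the actual-spend threshold has not been hit yet.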
Best Practices
Security Best Practices
1. Identity and Access
├─ Enable MFA for all users
├─ Use IAM roles instead of access keys
├─ Implement least privilege principle
├─ Rotate credentials regularly
└─ Use AWS SSO for centralized access
2. Network Security
├─ Use VPC with public/private subnets
├─ Implement security groups properly
├─ Use Network ACLs as additional layer
├─ Enable VPC Flow Logs
└─ Use AWS WAF for web applications
3. Data Protection
├─ Enable encryption at rest
├─ Use TLS/SSL for data in transit
├─ Regular backups and snapshots
├─ Enable versioning on S3
└─ Use KMS for key management
4. Monitoring and Logging
├─ Enable CloudTrail for all regions
├─ Use CloudWatch for monitoring
├─ Set up security alerts
├─ Regular security audits
└─ Use AWS Config for compliance
5. Incident Response
├─ Have incident response plan
├─ Use AWS Systems Manager
├─ Enable automated responses
└─ Regular disaster recovery drills
Performance Best Practices
1. Compute
├─ Choose appropriate instance types
├─ Use Auto Scaling
├─ Implement load balancing
├─ Consider serverless for variable workloads
└─ Use placement groups for HPC
2. Storage
├─ Use EBS-optimized instances
├─ Choose correct EBS volume type
├─ Use S3 Transfer Acceleration
├─ Implement caching (CloudFront, ElastiCache)
└─ Use S3 multipart upload
3. Database
├─ Use read replicas for read-heavy workloads
├─ Enable query caching
├─ Use connection pooling
├─ Implement proper indexing
└─ Consider Aurora for better performance
4. Network
├─ Use CloudFront CDN
├─ Enable enhanced networking
├─ Use VPC endpoints
├─ Implement Route 53 routing policies
└─ Consider Direct Connect
Reliability Best Practices
1. High Availability
├─ Deploy across multiple AZs
├─ Use Multi-AZ for databases
├─ Implement auto-scaling
├─ Use Elastic Load Balancing
└─ Consider multi-region for critical workloads
2. Backup and Recovery
├─ Automated backups for RDS
├─ Regular EBS snapshots
├─ Enable S3 versioning
├─ Cross-region replication
└─ Test recovery procedures
3. Monitoring
├─ Set up CloudWatch alarms
├─ Use health checks
├─ Monitor key metrics
├─ Implement automated responses
└─ Use AWS X-Ray for tracing
4. Testing
├─ Regular load testing
├─ Chaos engineering
├─ Failover testing
└─ Disaster recovery drills
CLI Reference
Common CLI Patterns
# Use --query for filtering output
aws ec2 describe-instances \
--query 'Reservations[].Instances[].[InstanceId,State.Name]' \
--output table
# Use --filters for filtering resources
aws ec2 describe-instances \
--filters "Name=instance-state-name,Values=running" \
"Name=tag:Environment,Values=production"
# Use --output for different formats
aws ec2 describe-instances --output json
aws ec2 describe-instances --output yaml
aws ec2 describe-instances --output table
aws ec2 describe-instances --output text
# Use JMESPath for complex queries
aws ec2 describe-instances \
--query 'Reservations[].Instances[?State.Name==`running`].[InstanceId,PrivateIpAddress]'
# Paginate results
aws s3api list-objects-v2 \
--bucket my-bucket \
--max-items 100 \
--page-size 10
# Wait for resource to be ready
aws ec2 wait instance-running --instance-ids i-1234567890abcdef0
# Generate skeleton for complex commands
aws ec2 run-instances --generate-cli-skeleton > template.json
# Edit template.json
aws ec2 run-instances --cli-input-json file://template.json
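The JMESPath query shown earlier can be mimicked in plain Python on the describe-instances response shape. A sketch (the nested Reservations/Instances layout matches the real response; the instance data itself is made up):

```python
# Mimics the JMESPath query:
# Reservations[].Instances[?State.Name==`running`].[InstanceId,PrivateIpAddress]
response = {
    "Reservations": [
        {"Instances": [
            {"InstanceId": "i-111", "State": {"Name": "running"},
             "PrivateIpAddress": "10.0.0.5"},
            {"InstanceId": "i-222", "State": {"Name": "stopped"},
             "PrivateIpAddress": "10.0.0.6"},
        ]},
        {"Instances": [
            {"InstanceId": "i-333", "State": {"Name": "running"},
             "PrivateIpAddress": "10.0.1.7"},
        ]},
    ]
}

# Flatten reservations, keep only running instances.
running = [
    (i["InstanceId"], i["PrivateIpAddress"])
    for r in response["Reservations"]
    for i in r["Instances"]
    if i["State"]["Name"] == "running"
]
print(running)  # -> [('i-111', '10.0.0.5'), ('i-333', '10.0.1.7')]
```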
Useful Aliases
# Add to ~/.bashrc or ~/.zshrc
alias ec2-list='aws ec2 describe-instances --query "Reservations[].Instances[].[InstanceId,InstanceType,State.Name,PublicIpAddress,Tags[?Key=='\''Name'\''].Value|[0]]" --output table'
alias ec2-running='aws ec2 describe-instances --filters "Name=instance-state-name,Values=running" --query "Reservations[].Instances[].[InstanceId,InstanceType,PublicIpAddress]" --output table'
alias s3-buckets='aws s3 ls'
alias lambda-list='aws lambda list-functions --query "Functions[].[FunctionName,Runtime,LastModified]" --output table'
alias rds-list='aws rds describe-db-instances --query "DBInstances[].[DBInstanceIdentifier,DBInstanceStatus,Engine,DBInstanceClass]" --output table'
Certification Paths
AWS Certification Roadmap
Foundational
│
└─ AWS Certified Cloud Practitioner
│
├─ Associate Level
│ ├─ Solutions Architect Associate
│ ├─ Developer Associate
│ └─ SysOps Administrator Associate
│
└─ Professional Level
├─ Solutions Architect Professional
└─ DevOps Engineer Professional
Specialty (Optional)
├─ Security Specialty
├─ Machine Learning Specialty
├─ Advanced Networking Specialty
├─ Database Specialty
└─ Data Analytics Specialty
Resources
Official Documentation
- AWS Documentation: https://docs.aws.amazon.com
- AWS CLI Reference: https://awscli.amazonaws.com/v2/documentation/api/latest/reference/index.html
- AWS SDK Documentation: https://aws.amazon.com/tools/
Learning Resources
- AWS Training and Certification: https://aws.amazon.com/training/
- AWS Free Tier: https://aws.amazon.com/free/
- AWS Well-Architected Framework: https://aws.amazon.com/architecture/well-architected/
- AWS Samples: https://github.com/aws-samples
- AWS Workshops: https://workshops.aws/
Community
- r/aws: Reddit community
- AWS Forums: https://forums.aws.amazon.com/
- AWS re:Post: https://repost.aws/
- AWS User Groups: https://aws.amazon.com/developer/community/usergroups/
Tools
- AWS CLI: Command-line interface
- AWS SDKs: Python (Boto3), JavaScript, Java, .NET, etc.
- AWS CDK: Infrastructure as code using programming languages
- Terraform: Multi-cloud infrastructure as code
- LocalStack: Local AWS cloud emulator
Updated: January 2025
Microsoft Azure
Table of Contents
- Introduction
- Azure Global Infrastructure
- Getting Started
- Core Compute Services
- Storage Services
- Database Services
- Networking Services
- Serverless Services
- Container Services
- Security Services
- Monitoring and Management
- DevOps and CI/CD
- AI and Machine Learning
- Architecture Examples
- Azure vs AWS Comparison
- Cost Optimization
- Best Practices
- CLI Reference
Introduction
Microsoft Azure is a cloud computing platform providing 200+ services for building, deploying, and managing applications through Microsoft's global network of data centers.
Key Advantages
- Enterprise Integration: Seamless integration with Microsoft products (Office 365, Active Directory, Dynamics)
- Hybrid Cloud: Industry-leading hybrid cloud capabilities with Azure Arc
- Global Reach: 60+ regions (more than any other cloud provider)
- Compliance: Most comprehensive compliance offerings
- Windows Workloads: Best platform for .NET and Windows-based applications
- Developer Tools: Excellent integration with Visual Studio and GitHub
Azure Account Hierarchy
┌─────────────────────────────────────────────────┐
│ Microsoft Entra ID (formerly Azure AD) Tenant   │
│ (Organization-wide identity)                    │
└──────────────────┬──────────────────────────────┘
│
┌─────────▼─────────┐
│ Management Groups │
└─────────┬──────────┘
│
┌─────────▼─────────┐
│ Subscriptions │
│ ├─ Production │
│ ├─ Development │
│ └─ Testing │
└─────────┬──────────┘
│
┌─────────▼─────────┐
│ Resource Groups │
│ ├─ RG-Web │
│ ├─ RG-Database │
│ └─ RG-Network │
└─────────┬──────────┘
│
┌─────────▼─────────┐
│ Resources │
│ ├─ VMs │
│ ├─ Storage │
│ └─ Databases │
└────────────────────┘
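Role assignments in Azure inherit down this hierarchy: a role granted at a management group or subscription applies to every resource group and resource beneath it. A toy model of that scope inheritance (names are made up):

```python
# Toy Azure RBAC scope inheritance: a role assigned at any ancestor
# scope applies to every scope beneath it, never upward.
parents = {
    "rg-web": "sub-production",
    "sub-production": "mg-root",
    "mg-root": None,
}
assignments = {
    ("alice", "Reader"): "mg-root",       # granted at the management group
    ("bob", "Contributor"): "rg-web",     # granted at a resource group
}

def has_role(principal, role, scope):
    granted_at = assignments.get((principal, role))
    while scope is not None:              # walk up toward the tenant root
        if scope == granted_at:
            return True
        scope = parents[scope]
    return False

print(has_role("alice", "Reader", "rg-web"))  # -> True (inherited from mg-root)
print(has_role("bob", "Contributor", "sub-production"))  # -> False (no upward inheritance)
```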
Azure Global Infrastructure
Hierarchy
Geography (e.g., United States)
└─ Region (e.g., East US, West US)
└─ Availability Zones (3 per region)
└─ Data Centers
└─ Edge Locations (Azure Front Door)
Azure Regions
Azure has 60+ regions worldwide - more than any other cloud provider
Paired Regions: Each region is paired with another region for disaster recovery
- Example: East US ↔ West US
- Example: North Europe ↔ West Europe
Availability Zones
- 3 or more physically separate zones within a region
- Each zone has independent power, cooling, networking
- < 2ms latency between zones
- Not all regions have Availability Zones
Region Selection Criteria
Factor Consideration
────────────────────────────────────────────────
Latency Distance to users
Compliance Data residency requirements
Services Service availability varies
Cost Pricing differs by region
Paired Region Consider DR requirements
Getting Started
Azure CLI Installation
# Install Azure CLI (Linux)
curl -sL https://aka.ms/InstallAzureCLIDeb | sudo bash
# Install Azure CLI (macOS)
brew update && brew install azure-cli
# Install Azure CLI (Windows - PowerShell)
Invoke-WebRequest -Uri https://aka.ms/installazurecliwindows -OutFile .\AzureCLI.msi
Start-Process msiexec.exe -Wait -ArgumentList '/I AzureCLI.msi /quiet'
# Verify installation
az --version
# Login to Azure
az login
# Login with specific tenant
az login --tenant TENANT_ID
# Login with service principal
az login --service-principal \
--username APP_ID \
--password PASSWORD \
--tenant TENANT_ID
# Set default subscription
az account set --subscription "My Subscription"
# List subscriptions
az account list --output table
# Show current subscription
az account show
Azure PowerShell
# Install Azure PowerShell
Install-Module -Name Az -Repository PSGallery -Force
# Connect to Azure
Connect-AzAccount
# Set subscription
Set-AzContext -SubscriptionId "subscription-id"
# List subscriptions
Get-AzSubscription
# List resource groups
Get-AzResourceGroup
Basic Azure CLI Commands
# Get help
az help
az vm help
# List all resource groups
az group list --output table
# List all resources
az resource list --output table
# List available locations
az account list-locations --output table
# List available VM sizes
az vm list-sizes --location eastus --output table
# Interactive mode
az interactive
Core Compute Services
Azure Virtual Machines
Cloud-based virtual servers.
VM Series and Sizes
Series vCPU Memory Use Case AWS Equivalent
────────────────────────────────────────────────────────────────────────────
B-Series 1-20 0.5-80GB Burstable, dev/test t3
D-Series 2-96 8-384GB General purpose m5
F-Series 2-72 4-144GB Compute optimized c5
E-Series 2-96 16-672GB Memory optimized r5
M-Series 128-416 2-12TB Largest memory x1e
N-Series 6-24 112-448GB GPU instances p3/g4
VM Pricing Models
Model Discount Commitment Use Case
───────────────────────────────────────────────────────────────
Pay-as-you-go Baseline None Short-term
Reserved Instances Up to 72% 1-3 years Steady state
Spot VMs Up to 90% None Fault-tolerant
Azure Hybrid Benefit Up to 85% None Existing licenses
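The discount percentages above can be turned into a quick back-of-the-envelope comparison. The sketch below is illustrative only: the list price used is a made-up placeholder, not a real Azure rate, and actual discounts depend on region, term, and VM series.

```python
# Illustrative only: the list price below is a hypothetical placeholder,
# not a real Azure rate. Discounts mirror the ceilings in the table above.
def effective_hourly_rate(list_price: float, discount_pct: float) -> float:
    """Apply a pricing-model discount to a pay-as-you-go list price."""
    return round(list_price * (1 - discount_pct / 100), 4)

payg = 0.096  # hypothetical pay-as-you-go $/hour
print(effective_hourly_rate(payg, 0))    # pay-as-you-go baseline
print(effective_hourly_rate(payg, 72))   # reserved instance, up to 72% off
print(effective_hourly_rate(payg, 90))   # spot VM, up to 90% off
```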
VM CLI Examples
# Create resource group
az group create \
--name myResourceGroup \
--location eastus
# List available VM images
az vm image list --output table
az vm image list --publisher MicrosoftWindowsServer --output table
# Create Linux VM
az vm create \
--resource-group myResourceGroup \
--name myVM \
--image Ubuntu2204 \
--size Standard_B2s \
--admin-username azureuser \
--generate-ssh-keys \
--public-ip-sku Standard \
--tags Environment=Production Owner=IT
# Create Windows VM
az vm create \
--resource-group myResourceGroup \
--name myWindowsVM \
--image Win2022Datacenter \
--size Standard_D2s_v3 \
--admin-username azureuser \
--admin-password 'SecurePassword123!'
# List VMs
az vm list --output table
# Get VM details
az vm show \
--resource-group myResourceGroup \
--name myVM \
--show-details
# Start VM
az vm start \
--resource-group myResourceGroup \
--name myVM
# Stop VM (deallocate to stop billing)
az vm deallocate \
--resource-group myResourceGroup \
--name myVM
# Restart VM
az vm restart \
--resource-group myResourceGroup \
--name myVM
# Resize VM
az vm resize \
--resource-group myResourceGroup \
--name myVM \
--size Standard_D4s_v3
# Delete VM
az vm delete \
--resource-group myResourceGroup \
--name myVM \
--yes
# Open port
az vm open-port \
--resource-group myResourceGroup \
--name myVM \
--port 80 \
--priority 1001
# Run command on VM
az vm run-command invoke \
--resource-group myResourceGroup \
--name myVM \
--command-id RunShellScript \
--scripts "sudo apt-get update && sudo apt-get install -y nginx"
# Create VM from an existing OS disk (e.g. one created from a snapshot)
az vm create \
--resource-group myResourceGroup \
--name myRestoredVM \
--attach-os-disk myOSDisk \
--os-type Linux
# Get VM instance metadata (from within VM)
curl -H Metadata:true "http://169.254.169.254/metadata/instance?api-version=2021-02-01"
Custom Script Extension
# Add custom script extension (Linux)
az vm extension set \
--resource-group myResourceGroup \
--vm-name myVM \
--name customScript \
--publisher Microsoft.Azure.Extensions \
--settings '{"fileUris": ["https://raw.githubusercontent.com/user/repo/script.sh"],"commandToExecute": "./script.sh"}'
# Add custom script extension (Windows)
az vm extension set \
--resource-group myResourceGroup \
--vm-name myWindowsVM \
--name CustomScriptExtension \
--publisher Microsoft.Compute \
--settings '{"fileUris": ["https://example.com/script.ps1"],"commandToExecute": "powershell -ExecutionPolicy Unrestricted -File script.ps1"}'
Azure Virtual Machine Scale Sets (VMSS)
Auto-scaling groups of identical VMs.
VMSS Architecture
┌─────────────────────────────────────────────────┐
│ Azure Load Balancer │
└──────────────────┬──────────────────────────────┘
│
┌──────────┼──────────┐
│ │ │
┌───▼───┐ ┌───▼───┐ ┌───▼───┐
│ VM 1 │ │ VM 2 │ │ VM 3 │
└───────┘ └───────┘ └───────┘
│ │ │
└──────────┼──────────┘
│
┌──────────▼──────────┐
│ Virtual Machine │
│ Scale Set (VMSS) │
│ │
│ Min: 2 │
│ Current: 3 │
│ Max: 10 │
│ │
│ Scale Rules: │
│ CPU > 75%: +1 VM │
│ CPU < 25%: -1 VM │
└─────────────────────┘
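The scale rules in the diagram above can be sketched as a small function. This is only a model of the autoscale decision for intuition (real Azure autoscale evaluates metrics over windows with cooldowns); the thresholds and step size mirror the example rules.

```python
# Minimal sketch of the diagram's autoscale logic:
# CPU > 75%: +1 VM, CPU < 25%: -1 VM, capacity clamped to [min, max].
def next_capacity(current: int, avg_cpu: float,
                  minimum: int = 2, maximum: int = 10) -> int:
    if avg_cpu > 75:
        current += 1
    elif avg_cpu < 25:
        current -= 1
    return max(minimum, min(maximum, current))  # clamp to allowed range

print(next_capacity(3, 80))   # scale out -> 4
print(next_capacity(3, 10))   # scale in  -> 2
print(next_capacity(10, 90))  # already at max -> 10
```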
VMSS CLI Examples
# Create VMSS
az vmss create \
--resource-group myResourceGroup \
--name myScaleSet \
--image Ubuntu2204 \
--instance-count 3 \
--vm-sku Standard_B2s \
--admin-username azureuser \
--generate-ssh-keys \
--load-balancer myLoadBalancer \
--upgrade-policy-mode Automatic
# List VMSS
az vmss list --output table
# Scale manually
az vmss scale \
--resource-group myResourceGroup \
--name myScaleSet \
--new-capacity 5
# Create autoscale profile
az monitor autoscale create \
--resource-group myResourceGroup \
--resource myScaleSet \
--resource-type Microsoft.Compute/virtualMachineScaleSets \
--name myAutoscaleProfile \
--min-count 2 \
--max-count 10 \
--count 3
# Create autoscale rule (scale out)
az monitor autoscale rule create \
--resource-group myResourceGroup \
--autoscale-name myAutoscaleProfile \
--condition "Percentage CPU > 75 avg 5m" \
--scale out 1
# Create autoscale rule (scale in)
az monitor autoscale rule create \
--resource-group myResourceGroup \
--autoscale-name myAutoscaleProfile \
--condition "Percentage CPU < 25 avg 5m" \
--scale in 1
# List VMSS instances
az vmss list-instances \
--resource-group myResourceGroup \
--name myScaleSet \
--output table
# Update VMSS image
az vmss update \
--resource-group myResourceGroup \
--name myScaleSet \
--set virtualMachineProfile.storageProfile.imageReference.version=latest
# Start rolling upgrade
az vmss update-instances \
--resource-group myResourceGroup \
--name myScaleSet \
--instance-ids '*'
# Delete VMSS
az vmss delete \
--resource-group myResourceGroup \
--name myScaleSet
Azure App Service
PaaS for web applications.
# Create App Service Plan
az appservice plan create \
--name myAppServicePlan \
--resource-group myResourceGroup \
--sku B1 \
--is-linux
# Create Web App
az webapp create \
--resource-group myResourceGroup \
--plan myAppServicePlan \
--name myUniqueWebApp123 \
--runtime "NODE:18-lts"
# Deploy from GitHub
az webapp deployment source config \
--name myUniqueWebApp123 \
--resource-group myResourceGroup \
--repo-url https://github.com/user/repo \
--branch main \
--manual-integration
# Deploy from local Git
az webapp deployment source config-local-git \
--name myUniqueWebApp123 \
--resource-group myResourceGroup
# Deploy ZIP file
az webapp deployment source config-zip \
--resource-group myResourceGroup \
--name myUniqueWebApp123 \
--src app.zip
# Set environment variables
az webapp config appsettings set \
--resource-group myResourceGroup \
--name myUniqueWebApp123 \
--settings DB_HOST=mydb.database.windows.net DB_NAME=mydb
# Enable HTTPS only
az webapp update \
--resource-group myResourceGroup \
--name myUniqueWebApp123 \
--https-only true
# Scale up (change plan)
az appservice plan update \
--name myAppServicePlan \
--resource-group myResourceGroup \
--sku P1V2
# Scale out (add instances)
az appservice plan update \
--name myAppServicePlan \
--resource-group myResourceGroup \
--number-of-workers 3
# View logs
az webapp log tail \
--resource-group myResourceGroup \
--name myUniqueWebApp123
# Restart web app
az webapp restart \
--resource-group myResourceGroup \
--name myUniqueWebApp123
# Delete web app
az webapp delete \
--resource-group myResourceGroup \
--name myUniqueWebApp123
Storage Services
Azure Blob Storage
Object storage service (equivalent to AWS S3).
Blob Storage Types
Type Use Case Performance Cost
────────────────────────────────────────────────────────────────────
Block Blobs Text and binary data Standard/Premium $$
Append Blobs Logging data Standard $$
Page Blobs VHD files, random access Premium $$$
Blob Access Tiers
Tier Access Frequency Retrieval Time Cost
─────────────────────────────────────────────────────────
Hot Frequent Immediate $$$
Cool Infrequent (30d+) Immediate $$
Cold Rare (90d+) Immediate $
Archive Rare (180d+) Hours ¢
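The access-frequency thresholds in the table above suggest a simple rule of thumb for picking a tier. The sketch below is not official guidance, just the table's thresholds expressed as code; real tier selection should also weigh retrieval costs and minimum-retention penalties.

```python
# Rule-of-thumb tier picker based on the thresholds in the table above
# (30d+ -> Cool, 90d+ -> Cold, 180d+ -> Archive). A sketch, not guidance.
def suggest_tier(days_between_accesses: int) -> str:
    if days_between_accesses >= 180:
        return "Archive"
    if days_between_accesses >= 90:
        return "Cold"
    if days_between_accesses >= 30:
        return "Cool"
    return "Hot"

print(suggest_tier(7))    # Hot
print(suggest_tier(45))   # Cool
print(suggest_tier(365))  # Archive
```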
Blob Storage Architecture
┌─────────────────────────────────────────────────┐
│ Storage Account: mystorageaccount │
│ Location: eastus │
│ Replication: LRS/GRS/RA-GRS │
├─────────────────────────────────────────────────┤
│ │
│ Container: images (Blob Container) │
│ ├─ logo.png │
│ ├─ banner.jpg │
│ └─ photos/ │
│ ├─ photo1.jpg │
│ └─ photo2.jpg │
│ │
│ Container: documents │
│ ├─ report.pdf │
│ └─ invoice.xlsx │
│ │
│ Container: backups │
│ └─ database-backup.sql │
│ │
│ File Share: fileshare (Azure Files) │
│ ├─ shared/ │
│ └─ config/ │
│ │
│ Table Storage (NoSQL) │
│ Queue Storage (Message Queue) │
└─────────────────────────────────────────────────┘
Blob Storage CLI Examples
# Create storage account
az storage account create \
--name mystorageaccount123 \
--resource-group myResourceGroup \
--location eastus \
--sku Standard_LRS \
--kind StorageV2
# Get connection string
az storage account show-connection-string \
--name mystorageaccount123 \
--resource-group myResourceGroup
# Export connection string
export AZURE_STORAGE_CONNECTION_STRING="<connection-string>"
# Create container
az storage container create \
--name mycontainer \
--account-name mystorageaccount123 \
--public-access off
# Upload blob
az storage blob upload \
--container-name mycontainer \
--name myfile.txt \
--file ./local-file.txt \
--account-name mystorageaccount123
# Upload directory
az storage blob upload-batch \
--destination mycontainer \
--source ./local-directory \
--account-name mystorageaccount123
# Download blob
az storage blob download \
--container-name mycontainer \
--name myfile.txt \
--file ./downloaded-file.txt \
--account-name mystorageaccount123
# List blobs
az storage blob list \
--container-name mycontainer \
--account-name mystorageaccount123 \
--output table
# Copy blob
az storage blob copy start \
--source-container mycontainer \
--source-blob myfile.txt \
--destination-container backup \
--destination-blob myfile-backup.txt \
--account-name mystorageaccount123
# Generate SAS token
az storage blob generate-sas \
--container-name mycontainer \
--name myfile.txt \
--account-name mystorageaccount123 \
--permissions r \
--expiry 2024-12-31T23:59:59Z
# Set blob tier
az storage blob set-tier \
--container-name mycontainer \
--name myfile.txt \
--tier Cool \
--account-name mystorageaccount123
# Delete blob
az storage blob delete \
--container-name mycontainer \
--name myfile.txt \
--account-name mystorageaccount123
# Enable versioning
az storage account blob-service-properties update \
--account-name mystorageaccount123 \
--resource-group myResourceGroup \
--enable-versioning true
# Set lifecycle management policy
az storage account management-policy create \
--account-name mystorageaccount123 \
--resource-group myResourceGroup \
--policy @policy.json
Lifecycle Management Policy Example
{
"rules": [
{
"enabled": true,
"name": "MoveToArchive",
"type": "Lifecycle",
"definition": {
"actions": {
"baseBlob": {
"tierToCool": {
"daysAfterModificationGreaterThan": 30
},
"tierToArchive": {
"daysAfterModificationGreaterThan": 90
},
"delete": {
"daysAfterModificationGreaterThan": 365
}
}
},
"filters": {
"blobTypes": ["blockBlob"],
"prefixMatch": ["logs/"]
}
}
}
]
}
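To see what the policy above actually does, the sketch below simulates its effect on a single blob given its age. This is only a mental model of the rule evaluation (Azure applies these actions asynchronously, roughly once a day), matching the thresholds and `logs/` prefix filter from the JSON.

```python
# Simulates the MoveToArchive policy above for one blob; for intuition only.
def lifecycle_action(blob_name: str, days_since_modified: int) -> str:
    if not blob_name.startswith("logs/"):
        return "no action (prefix filter does not match)"
    if days_since_modified > 365:
        return "delete"
    if days_since_modified > 90:
        return "tierToArchive"
    if days_since_modified > 30:
        return "tierToCool"
    return "no action"

print(lifecycle_action("logs/app.log", 45))      # tierToCool
print(lifecycle_action("logs/app.log", 400))     # delete
print(lifecycle_action("images/logo.png", 400))  # no action (prefix filter does not match)
```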
Blob Storage SDK Example (Python)
from azure.storage.blob import BlobServiceClient, BlobClient, ContainerClient
from azure.storage.blob import BlobSasPermissions, generate_blob_sas
from datetime import datetime, timedelta
# Create blob service client
connection_string = "DefaultEndpointsProtocol=https;AccountName=..."
blob_service_client = BlobServiceClient.from_connection_string(connection_string)
# Create container
def create_container(container_name):
container_client = blob_service_client.create_container(container_name)
return container_client
# Upload blob
def upload_blob(container_name, blob_name, data):
blob_client = blob_service_client.get_blob_client(
container=container_name,
blob=blob_name
)
blob_client.upload_blob(data, overwrite=True)
print(f"Uploaded {blob_name}")
# Upload file
def upload_file(container_name, file_path, blob_name=None):
if blob_name is None:
blob_name = file_path.split('/')[-1]
blob_client = blob_service_client.get_blob_client(
container=container_name,
blob=blob_name
)
with open(file_path, "rb") as data:
blob_client.upload_blob(data, overwrite=True)
print(f"Uploaded {file_path} as {blob_name}")
# Download blob
def download_blob(container_name, blob_name, file_path):
blob_client = blob_service_client.get_blob_client(
container=container_name,
blob=blob_name
)
with open(file_path, "wb") as file:
data = blob_client.download_blob()
file.write(data.readall())
print(f"Downloaded {blob_name} to {file_path}")
# List blobs
def list_blobs(container_name):
container_client = blob_service_client.get_container_client(container_name)
blob_list = container_client.list_blobs()
for blob in blob_list:
print(f"{blob.name}: {blob.size} bytes")
# Generate SAS URL
def generate_sas_url(container_name, blob_name, account_name, account_key):
sas_token = generate_blob_sas(
account_name=account_name,
container_name=container_name,
blob_name=blob_name,
account_key=account_key,
permission=BlobSasPermissions(read=True),
expiry=datetime.utcnow() + timedelta(hours=1)
)
url = f"https://{account_name}.blob.core.windows.net/{container_name}/{blob_name}?{sas_token}"
return url
# Delete blob
def delete_blob(container_name, blob_name):
blob_client = blob_service_client.get_blob_client(
container=container_name,
blob=blob_name
)
blob_client.delete_blob()
print(f"Deleted {blob_name}")
# Usage
create_container("mycontainer")
upload_file("mycontainer", "./local-file.txt")
download_blob("mycontainer", "local-file.txt", "./downloaded.txt")
list_blobs("mycontainer")
Azure Files
Managed SMB/NFS file shares.
# Create file share
az storage share create \
--name myfileshare \
--account-name mystorageaccount123 \
--quota 100
# Upload file to share
az storage file upload \
--share-name myfileshare \
--source ./local-file.txt \
--account-name mystorageaccount123
# List files
az storage file list \
--share-name myfileshare \
--account-name mystorageaccount123 \
--output table
# Mount file share (Linux)
sudo mkdir /mnt/azure
sudo mount -t cifs //mystorageaccount123.file.core.windows.net/myfileshare /mnt/azure \
-o vers=3.0,username=mystorageaccount123,password=<storage-key>,dir_mode=0777,file_mode=0777
# Mount file share (Windows)
net use Z: \\mystorageaccount123.file.core.windows.net\myfileshare /user:Azure\mystorageaccount123 <storage-key>
# Add to /etc/fstab (Linux)
echo "//mystorageaccount123.file.core.windows.net/myfileshare /mnt/azure cifs vers=3.0,username=mystorageaccount123,password=<storage-key>,dir_mode=0777,file_mode=0777 0 0" | sudo tee -a /etc/fstab
Azure Disk Storage
Managed disks for VMs (equivalent to AWS EBS).
Disk Types
Type IOPS Throughput Use Case Cost
─────────────────────────────────────────────────────────────────────
Ultra Disk 160K+ 4,000 MB/s Mission-critical $$$$
Premium SSD v2 80K 1,200 MB/s Production DBs $$$
Premium SSD 20K 900 MB/s Production $$
Standard SSD 6K 750 MB/s Web servers $
Standard HDD 2K 500 MB/s Backup, dev/test ¢
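A common sizing exercise is matching a workload's required IOPS against the ceilings in the table above. The sketch below does exactly that; the figures are the table's rough maxima (not SLAs), and real selection should also consider throughput and latency requirements.

```python
# Picks the cheapest disk type whose listed IOPS ceiling (from the table
# above) covers the requirement. The numbers are rough maxima, not SLAs.
DISK_IOPS_CEILINGS = [
    ("Standard HDD", 2_000),
    ("Standard SSD", 6_000),
    ("Premium SSD", 20_000),
    ("Premium SSD v2", 80_000),
    ("Ultra Disk", 160_000),
]

def pick_disk(required_iops: int) -> str:
    for name, ceiling in DISK_IOPS_CEILINGS:
        if required_iops <= ceiling:
            return name
    return "Ultra Disk"  # 160K+ scales beyond the listed figure

print(pick_disk(1_500))    # Standard HDD
print(pick_disk(15_000))   # Premium SSD
print(pick_disk(200_000))  # Ultra Disk
```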
# Create managed disk
az disk create \
--resource-group myResourceGroup \
--name myDataDisk \
--size-gb 128 \
--sku Premium_LRS
# Attach disk to VM
az vm disk attach \
--resource-group myResourceGroup \
--vm-name myVM \
--name myDataDisk
# Detach disk
az vm disk detach \
--resource-group myResourceGroup \
--vm-name myVM \
--name myDataDisk
# Create snapshot
az snapshot create \
--resource-group myResourceGroup \
--name mySnapshot \
--source myDataDisk
# Create disk from snapshot
az disk create \
--resource-group myResourceGroup \
--name myRestoredDisk \
--source mySnapshot
# Increase disk size
az disk update \
--resource-group myResourceGroup \
--name myDataDisk \
--size-gb 256
Database Services
Azure SQL Database
Managed SQL Server database.
Service Tiers
Tier               vCores   Memory       Max DB Size  Use Case          Cost
────────────────────────────────────────────────────────────────────────────
Serverless         0.5-40   3-120GB      4TB          Variable load     $$
General Purpose    2-80     10.4-408GB   4TB          Balanced          $$
Business Critical  2-128    20.8-625GB   4TB          Mission-critical  $$$$
Hyperscale         2-128    20.8-625GB   100TB        Large databases   $$$
SQL Database CLI Examples
# Create SQL Server
az sql server create \
--name myuniquesqlserver123 \
--resource-group myResourceGroup \
--location eastus \
--admin-user sqladmin \
--admin-password 'SecurePassword123!'
# Configure firewall rule
az sql server firewall-rule create \
--resource-group myResourceGroup \
--server myuniquesqlserver123 \
--name AllowMyIP \
--start-ip-address 1.2.3.4 \
--end-ip-address 1.2.3.4
# Allow Azure services
az sql server firewall-rule create \
--resource-group myResourceGroup \
--server myuniquesqlserver123 \
--name AllowAzureServices \
--start-ip-address 0.0.0.0 \
--end-ip-address 0.0.0.0
# Create database
az sql db create \
--resource-group myResourceGroup \
--server myuniquesqlserver123 \
--name myDatabase \
--service-objective S0 \
--backup-storage-redundancy Local
# Create serverless database
az sql db create \
--resource-group myResourceGroup \
--server myuniquesqlserver123 \
--name myServerlessDB \
--edition GeneralPurpose \
--compute-model Serverless \
--family Gen5 \
--capacity 2 \
--auto-pause-delay 60
# List databases
az sql db list \
--resource-group myResourceGroup \
--server myuniquesqlserver123 \
--output table
# Scale database
az sql db update \
--resource-group myResourceGroup \
--server myuniquesqlserver123 \
--name myDatabase \
--service-objective S2
# Create read replica
az sql db replica create \
--name myDatabase \
--resource-group myResourceGroup \
--server myuniquesqlserver123 \
--partner-server myuniquesqlserver-replica \
--partner-resource-group myResourceGroup
# Create backup
az sql db export \
--resource-group myResourceGroup \
--server myuniquesqlserver123 \
--name myDatabase \
--admin-user sqladmin \
--admin-password 'SecurePassword123!' \
--storage-key-type StorageAccessKey \
--storage-key "<storage-key>" \
--storage-uri "https://mystorageaccount.blob.core.windows.net/backups/mydb.bacpac"
# Restore database
az sql db restore \
--resource-group myResourceGroup \
--server myuniquesqlserver123 \
--name myRestoredDB \
--source-database myDatabase \
--time "2024-01-01T00:00:00Z"
# Delete database
az sql db delete \
--resource-group myResourceGroup \
--server myuniquesqlserver123 \
--name myDatabase \
--yes
# Connect to database
sqlcmd -S myuniquesqlserver123.database.windows.net -d myDatabase -U sqladmin -P 'SecurePassword123!'
SQL Database Connection Example (Python)
import pyodbc
# Connection string
server = 'myuniquesqlserver123.database.windows.net'
database = 'myDatabase'
username = 'sqladmin'
password = 'SecurePassword123!'
driver = '{ODBC Driver 18 for SQL Server}'
connection_string = f'DRIVER={driver};SERVER={server};DATABASE={database};UID={username};PWD={password}'
# Connect to database
conn = pyodbc.connect(connection_string)
cursor = conn.cursor()
# Create table
cursor.execute('''
CREATE TABLE users (
id INT PRIMARY KEY IDENTITY,
name NVARCHAR(100),
email NVARCHAR(100),
created_at DATETIME DEFAULT GETDATE()
)
''')
# Insert data
cursor.execute("INSERT INTO users (name, email) VALUES (?, ?)", ('Alice', 'alice@example.com'))
conn.commit()
# Query data
cursor.execute("SELECT * FROM users")
rows = cursor.fetchall()
for row in rows:
print(f"ID: {row.id}, Name: {row.name}, Email: {row.email}")
# Close connection
cursor.close()
conn.close()
Azure Cosmos DB
Globally distributed NoSQL database.
Cosmos DB APIs
API Type Use Case AWS Equivalent
──────────────────────────────────────────────────────────────────────
Core (SQL) Document General purpose DynamoDB
MongoDB Document MongoDB compatibility DocumentDB
Cassandra Wide-column Cassandra workloads Keyspaces
Gremlin Graph Graph relationships Neptune
Table Key-value Simple key-value DynamoDB
Cosmos DB CLI Examples
# Create Cosmos DB account
az cosmosdb create \
--name mycosmosaccount123 \
--resource-group myResourceGroup \
--locations regionName=eastus failoverPriority=0 \
--locations regionName=westus failoverPriority=1 \
--default-consistency-level Session \
--enable-automatic-failover true
# Create database (SQL API)
az cosmosdb sql database create \
--account-name mycosmosaccount123 \
--resource-group myResourceGroup \
--name myDatabase
# Create container
az cosmosdb sql container create \
--account-name mycosmosaccount123 \
--resource-group myResourceGroup \
--database-name myDatabase \
--name myContainer \
--partition-key-path "/userId" \
--throughput 400
# Get connection string
az cosmosdb keys list \
--name mycosmosaccount123 \
--resource-group myResourceGroup \
--type connection-strings
# List databases
az cosmosdb sql database list \
--account-name mycosmosaccount123 \
--resource-group myResourceGroup
# Update throughput
az cosmosdb sql container throughput update \
--account-name mycosmosaccount123 \
--resource-group myResourceGroup \
--database-name myDatabase \
--name myContainer \
--throughput 1000
Cosmos DB SDK Example (Python)
from azure.cosmos import CosmosClient, PartitionKey, exceptions
# Initialize client
endpoint = "https://mycosmosaccount123.documents.azure.com:443/"
key = "<primary-key>"
client = CosmosClient(endpoint, key)
# Get database and container
database = client.get_database_client("myDatabase")
container = database.get_container_client("myContainer")
# Create item
item = {
'id': 'user-001',
'userId': 'user-001',
'name': 'Alice',
'email': 'alice@example.com',
'age': 30
}
container.create_item(body=item)
# Read item
item = container.read_item(item='user-001', partition_key='user-001')
print(item)
# Query items
query = "SELECT * FROM c WHERE c.age > @age"
parameters = [{"name": "@age", "value": 25}]
items = list(container.query_items(
query=query,
parameters=parameters,
enable_cross_partition_query=True
))
for item in items:
print(f"{item['name']}: {item['age']} years old")
# Update item
item['age'] = 31
container.replace_item(item='user-001', body=item)
# Delete item
container.delete_item(item='user-001', partition_key='user-001')
Azure Database for PostgreSQL/MySQL
Managed open-source databases.
# Create PostgreSQL server
az postgres flexible-server create \
--name mypostgresserver123 \
--resource-group myResourceGroup \
--location eastus \
--admin-user myadmin \
--admin-password 'SecurePassword123!' \
--sku-name Standard_B1ms \
--tier Burstable \
--storage-size 32
# Create MySQL server
az mysql flexible-server create \
--name mymysqlserver123 \
--resource-group myResourceGroup \
--location eastus \
--admin-user myadmin \
--admin-password 'SecurePassword123!' \
--sku-name Standard_B1ms \
--tier Burstable \
--storage-size 32
# Configure firewall
az postgres flexible-server firewall-rule create \
--resource-group myResourceGroup \
--name mypostgresserver123 \
--rule-name AllowMyIP \
--start-ip-address 1.2.3.4 \
--end-ip-address 1.2.3.4
# Connect to PostgreSQL
psql "host=mypostgresserver123.postgres.database.azure.com port=5432 dbname=postgres user=myadmin password=SecurePassword123! sslmode=require"
# Connect to MySQL
mysql -h mymysqlserver123.mysql.database.azure.com -u myadmin -p
Azure Cache for Redis
Managed Redis cache.
# Create Redis cache
az redis create \
--resource-group myResourceGroup \
--name myrediscache123 \
--location eastus \
--sku Basic \
--vm-size c0
# Get access keys
az redis list-keys \
--resource-group myResourceGroup \
--name myrediscache123
# Get hostname
az redis show \
--resource-group myResourceGroup \
--name myrediscache123 \
--query hostName
# Connect to Redis
redis-cli -h myrediscache123.redis.cache.windows.net -p 6380 -a <primary-key> --tls
Networking Services
Azure Virtual Network (VNet)
Isolated network (equivalent to AWS VPC).
VNet Architecture
┌─────────────────────────────────────────────────────────────┐
│ VNet: my-vnet (10.0.0.0/16) │
│ Region: eastus │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌───────────────────────┐ ┌──────────────────────────┐ │
│ │ Public Subnet │ │ Public Subnet │ │
│ │ 10.0.1.0/24 │ │ 10.0.2.0/24 │ │
│ │ (AZ 1) │ │ (AZ 2) │ │
│ │ │ │ │ │
│ │ ┌─────────────────┐ │ │ ┌─────────────────┐ │ │
│ │ │ Load Balancer │ │ │ │ Load Balancer │ │ │
│ │ └─────────────────┘ │ │ └─────────────────┘ │ │
│ └───────────┬───────────┘ └──────────┬───────────────┘ │
│ │ │ │
│ │ Azure Gateway │ │
│ └──────────┬──────────────┘ │
│ │ │
│ ┌───────────────────────┐ ┌──────────────────────────┐ │
│ │ Private Subnet │ │ Private Subnet │ │
│ │ 10.0.11.0/24 │ │ 10.0.12.0/24 │ │
│ │ (AZ 1) │ │ (AZ 2) │ │
│ │ │ │ │ │
│ │ ┌─────┐ ┌─────┐ │ │ ┌─────┐ ┌─────┐ │ │
│ │ │ VM │ │ VM │ │ │ │ VM │ │ VM │ │ │
│ │ └─────┘ └─────┘ │ │ └─────┘ └─────┘ │ │
│ └───────────────────────┘ └──────────────────────────┘ │
│ │
│ ┌───────────────────────┐ ┌──────────────────────────┐ │
│ │ Database Subnet │ │ Database Subnet │ │
│ │ 10.0.21.0/24 │ │ 10.0.22.0/24 │ │
│ │ │ │ │ │
│ │ ┌──────────┐ │ │ ┌──────────┐ │ │
│ │ │ SQL DB │ │ │ │ SQL DB │ │ │
│ │ └──────────┘ │ │ └──────────┘ │ │
│ └───────────────────────┘ └──────────────────────────┘ │
└─────────────────────────────────────────────────────────────┘
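Before creating the VNet and subnets from the diagram above, it is worth sanity-checking the address plan: every subnet must fall inside the VNet's address space and no two subnets may overlap. The stdlib sketch below validates the diagram's CIDRs without any Azure calls.

```python
import ipaddress

# Validates the subnet plan from the diagram above: each subnet must sit
# inside the VNet range, and no two subnets may overlap. Pure stdlib.
vnet = ipaddress.ip_network("10.0.0.0/16")
subnets = ["10.0.1.0/24", "10.0.2.0/24", "10.0.11.0/24",
           "10.0.12.0/24", "10.0.21.0/24", "10.0.22.0/24"]

nets = [ipaddress.ip_network(s) for s in subnets]
assert all(n.subnet_of(vnet) for n in nets), "subnet outside VNet range"
assert all(not a.overlaps(b)
           for i, a in enumerate(nets) for b in nets[i + 1:]), "overlap"
print("subnet plan is valid")
```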
VNet CLI Examples
# Create VNet
az network vnet create \
--resource-group myResourceGroup \
--name myVNet \
--address-prefix 10.0.0.0/16 \
--location eastus
# Create subnet
az network vnet subnet create \
--resource-group myResourceGroup \
--vnet-name myVNet \
--name PublicSubnet \
--address-prefixes 10.0.1.0/24
az network vnet subnet create \
--resource-group myResourceGroup \
--vnet-name myVNet \
--name PrivateSubnet \
--address-prefixes 10.0.11.0/24
# List VNets
az network vnet list --output table
# List subnets
az network vnet subnet list \
--resource-group myResourceGroup \
--vnet-name myVNet \
--output table
# Create Network Security Group (NSG)
az network nsg create \
--resource-group myResourceGroup \
--name myNSG
# Add NSG rule
az network nsg rule create \
--resource-group myResourceGroup \
--nsg-name myNSG \
--name AllowHTTP \
--priority 100 \
--source-address-prefixes '*' \
--source-port-ranges '*' \
--destination-address-prefixes '*' \
--destination-port-ranges 80 \
--access Allow \
--protocol Tcp \
--direction Inbound
az network nsg rule create \
--resource-group myResourceGroup \
--nsg-name myNSG \
--name AllowSSH \
--priority 110 \
--source-address-prefixes 'VirtualNetwork' \
--source-port-ranges '*' \
--destination-address-prefixes '*' \
--destination-port-ranges 22 \
--access Allow \
--protocol Tcp \
--direction Inbound
# Associate NSG with subnet
az network vnet subnet update \
--resource-group myResourceGroup \
--vnet-name myVNet \
--name PublicSubnet \
--network-security-group myNSG
# Create NAT Gateway
az network public-ip create \
--resource-group myResourceGroup \
--name myNATGatewayIP \
--sku Standard \
--allocation-method Static
az network nat gateway create \
--resource-group myResourceGroup \
--name myNATGateway \
--public-ip-addresses myNATGatewayIP \
--idle-timeout 10
# Associate NAT Gateway with subnet
az network vnet subnet update \
--resource-group myResourceGroup \
--vnet-name myVNet \
--name PrivateSubnet \
--nat-gateway myNATGateway
# VNet peering
az network vnet peering create \
--resource-group myResourceGroup \
--name myVNet-to-VNet2 \
--vnet-name myVNet \
--remote-vnet myVNet2 \
--allow-vnet-access
Azure Load Balancer
Distributes traffic across backend resources (equivalent to AWS ELB).
Load Balancer Types
Type                 SKU       OSI Layer  Use Case            Cost
──────────────────────────────────────────────────────────────────────
Load Balancer        Basic     Layer 4    Internal/Public     Free
Load Balancer        Standard  Layer 4    Production          $$
Application Gateway  Standard  Layer 7    HTTP/HTTPS routing  $$
# Create public IP
az network public-ip create \
--resource-group myResourceGroup \
--name myPublicIP \
--sku Standard
# Create load balancer
az network lb create \
--resource-group myResourceGroup \
--name myLoadBalancer \
--sku Standard \
--public-ip-address myPublicIP \
--frontend-ip-name myFrontEnd \
--backend-pool-name myBackEndPool
# Create health probe
az network lb probe create \
--resource-group myResourceGroup \
--lb-name myLoadBalancer \
--name myHealthProbe \
--protocol tcp \
--port 80 \
--interval 15 \
--threshold 2
# Create load balancer rule
az network lb rule create \
--resource-group myResourceGroup \
--lb-name myLoadBalancer \
--name myHTTPRule \
--protocol tcp \
--frontend-port 80 \
--backend-port 80 \
--frontend-ip-name myFrontEnd \
--backend-pool-name myBackEndPool \
--probe-name myHealthProbe
# Add VM to backend pool
az network nic ip-config address-pool add \
--resource-group myResourceGroup \
--nic-name myNIC \
--ip-config-name ipconfig1 \
--lb-name myLoadBalancer \
--address-pool myBackEndPool
Azure Application Gateway
Layer 7 load balancer with WAF.
# Create Application Gateway
az network application-gateway create \
--name myAppGateway \
--resource-group myResourceGroup \
--location eastus \
--vnet-name myVNet \
--subnet PublicSubnet \
--capacity 2 \
--sku Standard_v2 \
--public-ip-address myPublicIP \
--servers 10.0.11.4 10.0.11.5
# Create path-based routing rule
az network application-gateway url-path-map create \
--gateway-name myAppGateway \
--resource-group myResourceGroup \
--name myPathMap \
--paths /images/* \
--http-settings appGatewayBackendHttpSettings \
--address-pool imagesBackendPool
# Enable Web Application Firewall (WAF)
az network application-gateway waf-config set \
--gateway-name myAppGateway \
--resource-group myResourceGroup \
--enabled true \
--firewall-mode Prevention \
--rule-set-version 3.0
Azure DNS
DNS hosting service.
# Create DNS zone
az network dns zone create \
--resource-group myResourceGroup \
--name example.com
# Create A record
az network dns record-set a add-record \
--resource-group myResourceGroup \
--zone-name example.com \
--record-set-name www \
--ipv4-address 1.2.3.4
# Create CNAME record
az network dns record-set cname set-record \
--resource-group myResourceGroup \
--zone-name example.com \
--record-set-name blog \
--cname www.example.com
# List records
az network dns record-set list \
--resource-group myResourceGroup \
--zone-name example.com
# Get nameservers
az network dns zone show \
--resource-group myResourceGroup \
--name example.com \
--query nameServers
Serverless Services
Azure Functions
Serverless compute (equivalent to AWS Lambda).
Function Runtime Versions
Runtime         Languages                              Timeout (Consumption)
────────────────────────────────────────────────────────────────────────────
4.x (Current)   C#, Java, JavaScript, TypeScript,      5 min default,
                Python, PowerShell                     10 min max
Function Triggers
Trigger Type Use Case
────────────────────────────────────────────────────
HTTP REST APIs, webhooks
Timer Scheduled tasks
Blob Storage File processing
Queue Storage Async processing
Event Grid Event-driven workflows
Event Hub Real-time data streams
Service Bus Enterprise messaging
Cosmos DB Database change feed
Function Example (Python)
import logging
import azure.functions as func
def main(req: func.HttpRequest) -> func.HttpResponse:
logging.info('Python HTTP trigger function processed a request.')
name = req.params.get('name')
if not name:
try:
req_body = req.get_json()
name = req_body.get('name')
except ValueError:
pass
if name:
return func.HttpResponse(
f"Hello, {name}!",
status_code=200
)
else:
return func.HttpResponse(
"Please pass a name parameter",
status_code=400
)
# Blob trigger example
def main(myblob: func.InputStream):
logging.info(f"Processing blob: {myblob.name}")
logging.info(f"Blob size: {myblob.length} bytes")
# Process the blob
content = myblob.read()
# Do something with content
# Timer trigger example
def main(mytimer: func.TimerRequest) -> None:
logging.info('Timer trigger function executed.')
if mytimer.past_due:
logging.info('The timer is past due!')
# Perform scheduled task
perform_maintenance()
# Queue trigger example
def main(msg: func.QueueMessage) -> None:
logging.info(f'Processing queue message: {msg.get_body().decode("utf-8")}')
# Process message
process_order(msg.get_json())
function.json Configuration
{
"bindings": [
{
"authLevel": "function",
"type": "httpTrigger",
"direction": "in",
"name": "req",
"methods": ["get", "post"]
},
{
"type": "http",
"direction": "out",
"name": "$return"
}
]
}
Azure Functions CLI Examples
# Install Azure Functions Core Tools
npm install -g azure-functions-core-tools@4
# Create function app locally
func init myFunctionApp --python
cd myFunctionApp
# Create new function
func new --name HttpTrigger --template "HTTP trigger"
# Run locally
func start
# Create function app in Azure
az functionapp create \
--resource-group myResourceGroup \
--consumption-plan-location eastus \
--runtime python \
--runtime-version 3.11 \
--functions-version 4 \
--name myuniquefunctionapp123 \
--storage-account mystorageaccount123 \
--os-type Linux
# Deploy to Azure
func azure functionapp publish myuniquefunctionapp123
# View logs
func azure functionapp logstream myuniquefunctionapp123
# Set application settings
az functionapp config appsettings set \
--name myuniquefunctionapp123 \
--resource-group myResourceGroup \
--settings "DB_CONNECTION_STRING=Server=..."
# Enable managed identity
az functionapp identity assign \
--name myuniquefunctionapp123 \
--resource-group myResourceGroup
# List functions
az functionapp function list \
--name myuniquefunctionapp123 \
--resource-group myResourceGroup
# Delete function app
az functionapp delete \
--name myuniquefunctionapp123 \
--resource-group myResourceGroup
Azure Functions Pricing
Plan Price Timeout Scaling
──────────────────────────────────────────────────────────────
Consumption $0.20/million requests 10 min Automatic
+ $0.000016/GB-s
Premium $0.169/vCPU hour Unlimited Automatic
+ $0.0123/GB hour
Dedicated App Service Plan cost Unlimited Manual/Auto
Free Tier: 1M requests + 400,000 GB-s/month
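As a worked example of the Consumption-plan rates above, a month's bill is executions × $0.20/million plus GB-seconds × $0.000016, after subtracting the free grant. A minimal sketch (the 3M-execution workload below is hypothetical):

```python
# Estimate Azure Functions Consumption-plan cost from the rates in the table above.
EXEC_PRICE = 0.20 / 1_000_000   # $ per execution
GBS_PRICE = 0.000016            # $ per GB-second
FREE_EXECS = 1_000_000          # monthly free grant: 1M requests
FREE_GBS = 400_000              # monthly free grant: 400,000 GB-s

def monthly_cost(executions, avg_seconds, memory_gb):
    """Cost after the free grant, given average duration and memory per execution."""
    gb_seconds = executions * avg_seconds * memory_gb
    billable_execs = max(0, executions - FREE_EXECS)
    billable_gbs = max(0, gb_seconds - FREE_GBS)
    return billable_execs * EXEC_PRICE + billable_gbs * GBS_PRICE

# 3M executions/month, 0.5 s each, 0.5 GB memory:
# 750,000 GB-s total, 350,000 billable -> $5.60, plus 2M billable execs -> $0.40
print(f"${monthly_cost(3_000_000, 0.5, 0.5):.2f}")  # → $6.00
```

Note how the per-GB-second charge dominates the per-request charge for all but the shortest functions.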
Azure Logic Apps
Workflow automation (similar to AWS Step Functions).
# Create Logic App
az logic workflow create \
--resource-group myResourceGroup \
--location eastus \
--name myLogicApp \
--definition @workflow.json
# List Logic Apps
az logic workflow list \
--resource-group myResourceGroup
# Show Logic App
az logic workflow show \
--resource-group myResourceGroup \
--name myLogicApp
# Run Logic App
az logic workflow run trigger \
--resource-group myResourceGroup \
--name myLogicApp \
--trigger-name manual
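The `--definition @workflow.json` argument above expects a workflow definition file, which is never shown. A minimal sketch of one (an HTTP-triggered workflow with a single Response action; exact wrapping may vary by API version, so treat this as illustrative):

```json
{
  "definition": {
    "$schema": "https://schema.management.azure.com/providers/Microsoft.Logic/schemas/2016-06-01/workflowdefinition.json#",
    "contentVersion": "1.0.0.0",
    "triggers": {
      "manual": {
        "type": "Request",
        "kind": "Http"
      }
    },
    "actions": {
      "Response": {
        "type": "Response",
        "inputs": {
          "statusCode": 200,
          "body": "Hello from Logic Apps"
        },
        "runAfter": {}
      }
    },
    "outputs": {}
  }
}
```

The `manual` trigger name here matches the `--trigger-name manual` used in the run command above.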
Container Services
Azure Container Instances (ACI)
Serverless containers (similar to AWS Fargate).
# Create container instance
az container create \
--resource-group myResourceGroup \
--name mycontainer \
--image nginx:latest \
--cpu 1 \
--memory 1.5 \
--dns-name-label myuniquecontainer123 \
--ports 80
# List containers
az container list --output table
# Get container logs
az container logs \
--resource-group myResourceGroup \
--name mycontainer
# Execute command in container
az container exec \
--resource-group myResourceGroup \
--name mycontainer \
--exec-command "/bin/bash"
# Delete container
az container delete \
--resource-group myResourceGroup \
--name mycontainer \
--yes
# Create container with environment variables
az container create \
--resource-group myResourceGroup \
--name myapp \
--image myregistry.azurecr.io/myapp:latest \
--cpu 2 \
--memory 4 \
--environment-variables \
'DB_HOST'='mydb.database.windows.net' \
'DB_NAME'='mydb' \
--secure-environment-variables \
'DB_PASSWORD'='SecurePassword123!' \
--registry-login-server myregistry.azurecr.io \
--registry-username myregistry \
--registry-password <password>
Azure Kubernetes Service (AKS)
Managed Kubernetes.
# Create AKS cluster
az aks create \
--resource-group myResourceGroup \
--name myAKSCluster \
--node-count 3 \
--node-vm-size Standard_D2s_v3 \
--enable-managed-identity \
--generate-ssh-keys \
--network-plugin azure \
--enable-addons monitoring
# Get credentials
az aks get-credentials \
--resource-group myResourceGroup \
--name myAKSCluster
# Verify connection
kubectl get nodes
# Scale cluster
az aks scale \
--resource-group myResourceGroup \
--name myAKSCluster \
--node-count 5
# Upgrade cluster
az aks upgrade \
--resource-group myResourceGroup \
--name myAKSCluster \
--kubernetes-version 1.28.0
# Enable cluster autoscaler
az aks update \
--resource-group myResourceGroup \
--name myAKSCluster \
--enable-cluster-autoscaler \
--min-count 3 \
--max-count 10
# List available versions
az aks get-versions --location eastus --output table
# Delete cluster
az aks delete \
--resource-group myResourceGroup \
--name myAKSCluster \
--yes
Azure Container Registry (ACR)
Docker registry (similar to AWS ECR).
# Create container registry
az acr create \
--resource-group myResourceGroup \
--name myuniqueregistry123 \
--sku Basic
# Login to registry
az acr login --name myuniqueregistry123
# Tag image
docker tag myapp:latest myuniqueregistry123.azurecr.io/myapp:v1.0
# Push image
docker push myuniqueregistry123.azurecr.io/myapp:v1.0
# List images
az acr repository list --name myuniqueregistry123 --output table
# List tags
az acr repository show-tags \
--name myuniqueregistry123 \
--repository myapp \
--output table
# Delete image
az acr repository delete \
--name myuniqueregistry123 \
--image myapp:v1.0 \
--yes
Security Services
Azure Active Directory (now Microsoft Entra ID)
Identity and access management.
# Create user
az ad user create \
--display-name "Alice Smith" \
--user-principal-name alice@contoso.com \
--password SecurePassword123!
# Create group
az ad group create \
--display-name Developers \
--mail-nickname developers
# Add user to group
az ad group member add \
--group Developers \
--member-id <user-object-id>
# Create service principal
az ad sp create-for-rbac \
--name myServicePrincipal \
--role Contributor \
--scopes /subscriptions/<subscription-id>
# List users
az ad user list --output table
# List groups
az ad group list --output table
Azure Key Vault
Secrets management (similar to AWS Secrets Manager).
# Create Key Vault
az keyvault create \
--name myuniquekeyvault123 \
--resource-group myResourceGroup \
--location eastus
# Set secret
az keyvault secret set \
--vault-name myuniquekeyvault123 \
--name dbpassword \
--value "SecurePassword123!"
# Get secret
az keyvault secret show \
--vault-name myuniquekeyvault123 \
--name dbpassword \
--query value \
--output tsv
# List secrets
az keyvault secret list \
--vault-name myuniquekeyvault123 \
--output table
# Delete secret
az keyvault secret delete \
--vault-name myuniquekeyvault123 \
--name dbpassword
# Set access policy
az keyvault set-policy \
--name myuniquekeyvault123 \
--upn alice@contoso.com \
--secret-permissions get list set delete
# Create certificate
az keyvault certificate create \
--vault-name myuniquekeyvault123 \
--name mycert \
--policy "$(az keyvault certificate get-default-policy)"
# Import certificate
az keyvault certificate import \
--vault-name myuniquekeyvault123 \
--name imported-cert \
--file certificate.pfx
Use Key Vault in Application (Python)
from azure.identity import DefaultAzureCredential
from azure.keyvault.secrets import SecretClient
# Create client
credential = DefaultAzureCredential()
vault_url = "https://myuniquekeyvault123.vault.azure.net"
client = SecretClient(vault_url=vault_url, credential=credential)
# Get secret
secret = client.get_secret("dbpassword")
print(f"Secret value: {secret.value}")
# Set secret
client.set_secret("newsecret", "newvalue")
# List secrets
secrets = client.list_properties_of_secrets()
for secret in secrets:
    print(f"Secret name: {secret.name}")
# Delete secret
client.begin_delete_secret("newsecret").wait()
Azure RBAC (Role-Based Access Control)
# List role definitions
az role definition list --output table
# Assign role to user
az role assignment create \
--assignee alice@contoso.com \
--role Contributor \
--scope /subscriptions/<subscription-id>/resourceGroups/myResourceGroup
# Assign role to service principal
az role assignment create \
--assignee <service-principal-object-id> \
--role "Storage Blob Data Contributor" \
--scope /subscriptions/<subscription-id>/resourceGroups/myResourceGroup/providers/Microsoft.Storage/storageAccounts/mystorageaccount123
# List role assignments
az role assignment list \
--assignee alice@contoso.com \
--output table
# Remove role assignment
az role assignment delete \
--assignee alice@contoso.com \
--role Contributor \
--scope /subscriptions/<subscription-id>/resourceGroups/myResourceGroup
# Create custom role
az role definition create --role-definition @custom-role.json
Custom Role Definition Example
{
"Name": "Custom VM Operator",
"Description": "Can start and stop VMs",
"Actions": [
"Microsoft.Compute/virtualMachines/start/action",
"Microsoft.Compute/virtualMachines/restart/action",
"Microsoft.Compute/virtualMachines/deallocate/action",
"Microsoft.Compute/virtualMachines/read"
],
"NotActions": [],
"AssignableScopes": [
"/subscriptions/<subscription-id>"
]
}
Monitoring and Management
Azure Monitor
Monitoring and observability (similar to CloudWatch).
# Create action group
az monitor action-group create \
--name myActionGroup \
--resource-group myResourceGroup \
--short-name myAG \
--email-receiver name=admin email=admin@example.com
# Create metric alert
az monitor metrics alert create \
--name high-cpu \
--resource-group myResourceGroup \
--scopes /subscriptions/<subscription-id>/resourceGroups/myResourceGroup/providers/Microsoft.Compute/virtualMachines/myVM \
--condition "avg Percentage CPU > 80" \
--window-size 5m \
--evaluation-frequency 1m \
--action myActionGroup
# List alerts
az monitor metrics alert list \
--resource-group myResourceGroup
# Query metrics
az monitor metrics list \
--resource /subscriptions/<subscription-id>/resourceGroups/myResourceGroup/providers/Microsoft.Compute/virtualMachines/myVM \
--metric "Percentage CPU" \
--start-time 2024-01-01T00:00:00Z \
--end-time 2024-01-01T23:59:59Z \
--interval PT1H
Azure Log Analytics
Log collection and analysis.
# Create Log Analytics workspace
az monitor log-analytics workspace create \
--resource-group myResourceGroup \
--workspace-name myWorkspace \
--location eastus
# Query logs (KQL - Kusto Query Language)
az monitor log-analytics query \
--workspace myWorkspace \
--analytics-query "AzureActivity | where TimeGenerated > ago(1h) | summarize count() by OperationName"
# Example KQL queries
# All logs from last hour
"AzureActivity | where TimeGenerated > ago(1h)"
# Count errors by resource
"AzureDiagnostics | where Level == 'Error' | summarize count() by Resource"
# VM performance - CPU over 80%
"Perf | where CounterName == '% Processor Time' and CounterValue > 80"
# Failed login attempts
"SigninLogs | where ResultType != 0 | project TimeGenerated, UserPrincipalName, ResultType, ResultDescription"
Azure Application Insights
Application performance monitoring.
from applicationinsights import TelemetryClient
# Initialize client
tc = TelemetryClient('<instrumentation-key>')
# Track event
tc.track_event('UserLogin', {'user': 'alice@example.com'})
# Track metric
tc.track_metric('request_duration', 125.5)
# Track exception
try:
    result = 1 / 0
except Exception:
    tc.track_exception()
# Track request
tc.track_request('GET /api/users', 'https://myapi.com/api/users', True, 200, 125)
# Track dependency
tc.track_dependency('SQL', 'mydb.database.windows.net', 'SELECT * FROM users', 45, True, 'Query')
# Flush telemetry
tc.flush()
# Enable Application Insights for web app
az webapp config appsettings set \
--resource-group myResourceGroup \
--name myUniqueWebApp123 \
--settings "APPINSIGHTS_INSTRUMENTATIONKEY=<instrumentation-key>"
DevOps and CI/CD
Azure DevOps
Complete DevOps platform.
Azure Pipelines YAML Example
# azure-pipelines.yml
trigger:
  - main

pool:
  vmImage: 'ubuntu-latest'

variables:
  buildConfiguration: 'Release'

stages:
  - stage: Build
    jobs:
      - job: BuildJob
        steps:
          - task: UsePythonVersion@0
            inputs:
              versionSpec: '3.11'
          - script: |
              python -m pip install --upgrade pip
              pip install -r requirements.txt
            displayName: 'Install dependencies'
          - script: |
              pytest tests/ --junitxml=junit/test-results.xml
            displayName: 'Run tests'
          - task: PublishTestResults@2
            inputs:
              testResultsFiles: '**/test-results.xml'
          - script: |
              docker build -t myapp:$(Build.BuildId) .
            displayName: 'Build Docker image'
          - task: Docker@2
            inputs:
              containerRegistry: 'myACR'
              repository: 'myapp'
              command: 'push'
              tags: |
                $(Build.BuildId)
                latest

  - stage: Deploy
    dependsOn: Build
    jobs:
      - deployment: DeployJob
        environment: 'production'
        strategy:
          runOnce:
            deploy:
              steps:
                - task: AzureWebAppContainer@1
                  inputs:
                    azureSubscription: 'myServiceConnection'
                    appName: 'myUniqueWebApp123'
                    containers: 'myregistry.azurecr.io/myapp:$(Build.BuildId)'
Azure CLI for DevOps
# Create Azure DevOps project
az devops project create --name MyProject --org https://dev.azure.com/myorg
# Create pipeline
az pipelines create \
--name MyPipeline \
--repository https://github.com/user/repo \
--branch main \
--yml-path azure-pipelines.yml
# Run pipeline
az pipelines run --name MyPipeline
# List pipelines
az pipelines list --output table
# Show pipeline runs
az pipelines runs list --pipeline-name MyPipeline --output table
AI and Machine Learning
Azure OpenAI Service
Access to OpenAI models (GPT-4, GPT-3.5, DALL-E, Whisper).
# Note: this uses the legacy (pre-1.0) openai Python SDK interface
import openai
# Configure
openai.api_type = "azure"
openai.api_base = "https://myopenai.openai.azure.com/"
openai.api_version = "2023-05-15"
openai.api_key = "<api-key>"
# Generate completion
response = openai.ChatCompletion.create(
engine="gpt-4", # deployment name
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Explain cloud computing in simple terms."}
],
temperature=0.7,
max_tokens=800
)
print(response.choices[0].message.content)
# Generate image
response = openai.Image.create(
prompt="A futuristic cloud data center",
n=1,
size="1024x1024"
)
image_url = response['data'][0]['url']
Azure Cognitive Services
Pre-built AI services.
from azure.ai.textanalytics import TextAnalyticsClient
from azure.core.credentials import AzureKeyCredential
# Text Analytics
endpoint = "https://myservice.cognitiveservices.azure.com/"
key = "<api-key>"
client = TextAnalyticsClient(endpoint=endpoint, credential=AzureKeyCredential(key))
# Sentiment analysis
documents = ["I love Azure!", "This is terrible."]
result = client.analyze_sentiment(documents)
for doc in result:
    print(f"Sentiment: {doc.sentiment}, Confidence: {doc.confidence_scores}")

# Entity recognition
result = client.recognize_entities(["Microsoft was founded by Bill Gates."])
for doc in result:
    for entity in doc.entities:
        print(f"Entity: {entity.text}, Category: {entity.category}")

# Key phrase extraction
result = client.extract_key_phrases(["Azure is a cloud computing platform."])
for doc in result:
    print(f"Key phrases: {doc.key_phrases}")
Azure Machine Learning
End-to-end ML platform.
from azureml.core import Workspace, Experiment, ScriptRunConfig
# Connect to workspace
ws = Workspace.from_config()
# Create experiment
experiment = Experiment(workspace=ws, name='my-experiment')
# Configure training run
config = ScriptRunConfig(
source_directory='./src',
script='train.py',
compute_target='cpu-cluster',
environment='AzureML-sklearn-1.0'
)
# Submit run
run = experiment.submit(config)
run.wait_for_completion(show_output=True)
# Register model
model = run.register_model(
model_name='my-model',
model_path='outputs/model.pkl'
)
# Deploy model
from azureml.core.webservice import AciWebservice
from azureml.core.model import InferenceConfig, Model
inference_config = InferenceConfig(
entry_script='score.py',
environment='AzureML-sklearn-1.0'
)
aci_config = AciWebservice.deploy_configuration(
cpu_cores=1,
memory_gb=1
)
service = Model.deploy(
workspace=ws,
name='my-service',
models=[model],
inference_config=inference_config,
deployment_config=aci_config
)
service.wait_for_deployment(show_output=True)
print(f"Scoring URI: {service.scoring_uri}")
Architecture Examples
Three-Tier Web Application
Internet
│
┌────────▼────────┐
│ Azure Front │ CDN
│ Door │
└────────┬────────┘
│
┌────────▼────────┐
│ Azure DNS │
└────────┬────────┘
│
┌────────────────────────────▼────────────────────────────────┐
│ Virtual Network │
│ │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ Public Subnet (AZ-1) Public Subnet (AZ-2) │ │
│ │ ┌──────────────────┐ ┌──────────────────┐ │ │
│ │ │ Application │ │ Application │ │ │
│ │ │ Gateway + WAF │ │ Gateway + WAF │ │ │
│ │ └────────┬─────────┘ └────────┬─────────┘ │ │
│ └───────────┼──────────────────────┼─────────────────┘ │
│ │ │ │
│ ┌───────────▼──────────────────────▼─────────────────┐ │
│ │ Private Subnet (AZ-1) Private Subnet (AZ-2) │ │
│ │ ┌────────────────┐ ┌────────────────┐ │ │
│ │ │ VMSS │ │ VMSS │ │ │
│ │ │ ┌──┐ ┌──┐ │ │ ┌──┐ ┌──┐ │ │ │
│ │ │ │VM│ │VM│ │ │ │VM│ │VM│ │ │ │
│ │ │ └──┘ └──┘ │ │ └──┘ └──┘ │ │ │
│ │ └────────┬───────┘ └────────┬───────┘ │ │
│ └───────────┼──────────────────────┼─────────────────┘ │
│ │ │ │
│ ┌───────────▼──────────────────────▼─────────────────┐ │
│ │ Database Subnet (AZ-1) Database Subnet (AZ-2) │ │
│ │ ┌──────────────┐ ┌──────────────┐ │ │
│ │ │ Azure SQL │◄──────▶│ Azure SQL │ │ │
│ │ │ Primary │ │ Secondary │ │ │
│ │ └──────────────┘ └──────────────┘ │ │
│ │ │ │
│ │ ┌──────────────────────────────┐ │ │
│ │ │ Azure Cache for Redis │ │ │
│ │ └──────────────────────────────┘ │ │
│ └──────────────────────────────────────────────────────┘ │
│ │
│ Additional Services: │
│ ├─ Blob Storage: Static assets │
│ ├─ Key Vault: Secrets management │
│ ├─ Monitor: Monitoring and alerts │
│ └─ Application Insights: APM │
└──────────────────────────────────────────────────────────────┘
Serverless Microservices
┌─────────────┐
│ Users │
└──────┬──────┘
│
┌────────▼────────┐
│ Azure Front │
│ Door + Blob │
│ (Frontend) │
└────────┬────────┘
│
┌────────▼────────┐
│ API Management │
└────────┬────────┘
│
┌──────────────────────┼──────────────────────┐
│ │ │
┌────▼──────┐ ┌────▼──────┐ ┌────▼──────┐
│ Function │ │ Function │ │ Function │
│ User Svc │ │ Order Svc │ │ Pay Svc │
└────┬──────┘ └────┬──────┘ └────┬──────┘
│ │ │
┌────▼──────┐ ┌────▼──────┐ ┌────▼──────┐
│Cosmos DB │ │Cosmos DB │ │Cosmos DB │
│Users │ │Orders │ │Payments │
└───────────┘ └───────────┘ └───────────┘
│ │ │
└─────────────────────┼─────────────────────┘
│
┌──────▼──────┐
│ Event Grid │
│ Service Bus│
└─────────────┘
Azure vs AWS Comparison
Service Mapping
Service Category Azure AWS
─────────────────────────────────────────────────────────────────
Compute
VMs Virtual Machines EC2
Auto-scaling VMSS Auto Scaling
Serverless Functions Lambda
Containers AKS / ACI EKS / ECS / Fargate
PaaS App Service Elastic Beanstalk
Storage
Object Blob Storage S3
Block Managed Disks EBS
File Azure Files EFS
Archive Archive Storage Glacier
Database
Relational SQL Database RDS
NoSQL Document Cosmos DB DynamoDB
Cache Cache for Redis ElastiCache
Data Warehouse Synapse Analytics Redshift
Networking
Virtual Network VNet VPC
Load Balancer Load Balancer / App GW ELB / ALB / NLB
CDN Front Door / CDN CloudFront
DNS Azure DNS Route 53
VPN VPN Gateway VPN Gateway
Security
Identity Azure AD (Entra ID) IAM / Cognito
Secrets Key Vault Secrets Manager
Encryption Key Vault KMS
Monitoring
Metrics Azure Monitor CloudWatch
Logs Log Analytics CloudWatch Logs
APM Application Insights X-Ray
Audit Activity Log CloudTrail
DevOps
CI/CD Azure DevOps / Pipelines CodePipeline
Repository Azure Repos CodeCommit
Container Registry ACR ECR
AI/ML
Pre-trained Models Cognitive Services AI Services
ML Platform Machine Learning SageMaker
GenAI OpenAI Service Bedrock
Key Differences
Aspect Azure AWS
─────────────────────────────────────────────────────────────────
Market Share ~23% ~32%
Launch Year 2010 2006
Focus Enterprise / Hybrid Startups / Flexibility
Integration Microsoft stack Broad ecosystem
Regions 60+ regions 30+ regions
Pricing Per-minute billing Per-second billing
Support Strong enterprise Extensive documentation
Compliance Most certifications Extensive certifications
Hybrid Cloud Azure Arc (best-in-class) Outposts
Windows Workloads Native integration Good support
When to Choose Azure
✓ Heavy Microsoft stack usage (Windows, .NET, SQL Server)
✓ Enterprise Active Directory integration needed
✓ Hybrid cloud requirements (on-premises + cloud)
✓ Existing Microsoft licensing (Azure Hybrid Benefit)
✓ Office 365 / Dynamics 365 integration
✓ Strong compliance requirements
✓ European data centers needed
✓ .NET development team
When to Choose AWS
✓ Largest service selection needed
✓ Startup with flexible requirements
✓ Open-source technologies focus
✓ Mature ecosystem and tooling important
✓ Broadest region availability needed
✓ Extensive third-party integrations
✓ Strong serverless requirements
✓ Largest community and resources
Cost Optimization
Azure Cost Management
# Create budget
az consumption budget create \
--budget-name myBudget \
--amount 1000 \
--category Cost \
--time-grain Monthly \
--start-date 2024-01-01 \
--end-date 2024-12-31
# View cost analysis
az consumption usage list \
--start-date 2024-01-01 \
--end-date 2024-01-31
# Get cost forecast
az consumption forecast list
# Enable auto-shutdown for VMs
az vm auto-shutdown \
--resource-group myResourceGroup \
--name myVM \
--time 1900 \
--timezone "Pacific Standard Time"
Cost Optimization Strategies
┌──────────────────────────────────────────────────────────┐
│ Azure Cost Optimization Checklist │
├──────────────────────────────────────────────────────────┤
│ │
│ Compute │
│ ☐ Use Reserved Instances (up to 72% discount) │
│ ☐ Use Spot VMs for fault-tolerant workloads │
│ ☐ Right-size VMs based on metrics │
│ ☐ Use Azure Hybrid Benefit for Windows/SQL │
│ ☐ Deallocate VMs when not in use │
│ ☐ Use Azure Functions for event-driven workloads │
│ ☐ Enable auto-shutdown for dev/test VMs │
│ │
│ Storage │
│ ☐ Use lifecycle management policies │
│ ☐ Move infrequent data to Cool/Archive tiers │
│ ☐ Delete unused disks and snapshots │
│ ☐ Use LRS instead of GRS when possible │
│ ☐ Enable blob versioning only when needed │
│ │
│ Database │
│ ☐ Use serverless for SQL Database with variable load │
│ ☐ Right-size database tiers │
│ ☐ Use Cosmos DB autoscale │
│ ☐ Implement connection pooling │
│ ☐ Pause dev/test databases when not in use │
│ │
│ Network │
│ ☐ Use Azure Front Door to reduce data transfer │
│ ☐ Use VNet peering instead of VPN when possible │
│ ☐ Consolidate data transfer within same region │
│ ☐ Use private endpoints to avoid data transfer costs │
│ │
│ Monitoring │
│ ☐ Set up Azure Cost Management + Billing alerts │
│ ☐ Use Azure Advisor cost recommendations │
│ ☐ Review Advisor score regularly │
│ ☐ Use tags for cost allocation │
│ ☐ Review Underutilized Resources report │
└──────────────────────────────────────────────────────────┘
Azure Pricing Calculator
Use Azure Pricing Calculator: https://azure.microsoft.com/pricing/calculator/
Example Monthly Costs
Service Configuration Monthly Cost (Approx)
─────────────────────────────────────────────────────────────────────
VM (B2s) 2 vCPU, 4GB, Linux $30
Managed Disk (128GB) Premium SSD $20
SQL Database (S0) 10 DTUs $15
Cosmos DB 400 RU/s $24
Blob Storage (100GB) Hot tier $2
Data Transfer 50GB outbound $4
App Service (B1) 1 core, 1.75GB $55
Functions 1M requests $0.20
─────────
Total: ~$150/month
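The line items above can be sanity-checked with simple arithmetic; a sketch using the approximate figures from the table:

```python
# Sum the approximate monthly line items from the cost table above.
costs = {
    "VM (B2s)": 30,
    "Managed Disk (128GB)": 20,
    "SQL Database (S0)": 15,
    "Cosmos DB (400 RU/s)": 24,
    "Blob Storage (100GB)": 2,
    "Data Transfer (50GB)": 4,
    "App Service (B1)": 55,
    "Functions (1M requests)": 0.20,
}
total = sum(costs.values())
print(f"Total: ~${total:.2f}/month")  # → Total: ~$150.20/month
```

Tracking an estimate like this against the actual invoice is a quick way to catch forgotten resources.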
Best Practices
Security Best Practices
1. Identity and Access
├─ Use Azure AD (Entra ID) for all authentication
├─ Enable MFA for all users
├─ Use managed identities instead of service principals
├─ Implement RBAC with least privilege
├─ Use Azure AD Privileged Identity Management (PIM)
└─ Enable Conditional Access policies
2. Network Security
├─ Use Network Security Groups (NSGs)
├─ Implement Azure Firewall or third-party NVA
├─ Use private endpoints for PaaS services
├─ Enable DDoS Protection Standard for production
├─ Use Application Gateway with WAF
└─ Enable VNet service endpoints
3. Data Protection
├─ Enable encryption at rest for all services
├─ Use Azure Key Vault for secrets
├─ Enable TLS 1.2+ for data in transit
├─ Implement backup and disaster recovery
├─ Enable soft delete for Key Vault and Storage
└─ Use customer-managed keys when required
4. Monitoring and Compliance
├─ Enable Azure Security Center (Defender for Cloud)
├─ Use Azure Sentinel for SIEM
├─ Enable Azure Monitor and Log Analytics
├─ Implement Azure Policy for governance
├─ Use Azure Blueprints for compliance
└─ Regular security assessments
5. Application Security
├─ Use Web Application Firewall (WAF)
├─ Implement API Management security features
├─ Enable Application Insights
├─ Use Azure Front Door for global apps
└─ Regular vulnerability scanning
Reliability Best Practices
1. High Availability
├─ Deploy across Availability Zones
├─ Use zone-redundant services
├─ Implement auto-scaling
├─ Use Azure Load Balancer / Application Gateway
└─ Consider multi-region for critical workloads
2. Disaster Recovery
├─ Define RPO and RTO requirements
├─ Use Azure Site Recovery
├─ Implement geo-redundant storage
├─ Regular backup and restore testing
└─ Document DR procedures
3. Monitoring
├─ Use Azure Monitor for all resources
├─ Set up alerts for critical metrics
├─ Implement health checks
├─ Use Application Insights for APM
└─ Create dashboards for visibility
4. Resilience
├─ Implement retry logic
├─ Use circuit breaker pattern
├─ Implement graceful degradation
├─ Use queue-based load leveling
└─ Regular chaos engineering tests
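The retry-logic item above can be sketched in a few lines of Python. This is a minimal illustration of exponential backoff with jitter, not tied to any Azure SDK (the Azure SDKs ship their own configurable retry policies); the `flaky` function is a made-up stand-in for a transient failure:

```python
import random
import time

def retry(func, attempts=5, base_delay=0.1, max_delay=5.0):
    """Call func, retrying on exception with exponential backoff and jitter."""
    for attempt in range(attempts):
        try:
            return func()
        except Exception:
            if attempt == attempts - 1:
                raise  # out of attempts: surface the last error
            delay = min(max_delay, base_delay * 2 ** attempt)
            time.sleep(delay * random.uniform(0.5, 1.0))  # jitter avoids thundering herd

# Hypothetical transient failure: fails twice, then succeeds.
calls = {"n": 0}
def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("transient failure")
    return "ok"

print(retry(flaky, base_delay=0.01))  # → ok
```

A circuit breaker builds on the same idea: after repeated failures it stops calling the dependency entirely for a cool-down period instead of retrying.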
CLI Reference
Common CLI Patterns
# Use --output for different formats
az vm list --output table
az vm list --output json
az vm list --output yaml
az vm list --output tsv
# Use --query for filtering (JMESPath)
az vm list --query "[].{name:name, powerState:powerState}"
az vm list --query "[?powerState=='VM running'].name"
# Use --resource-group shorthand
az vm list -g myResourceGroup
# Use --verbose for debugging
az vm create --verbose ...
# Get help
az vm --help
az vm create --help
# Interactive mode
az interactive
# Configure defaults
az configure --defaults group=myResourceGroup location=eastus
# Show defaults
az configure --list-defaults
Useful Aliases
# Add to ~/.bashrc or ~/.zshrc
alias azvm='az vm list --output table'
alias azrunning='az vm list --query "[?powerState=='\''VM running'\''].{name:name, resourceGroup:resourceGroup}" --output table'
alias azstorage='az storage account list --output table'
alias azsql='az sql db list --output table'
alias azgroup='az group list --output table'
Certification Paths
Azure Certification Roadmap
Foundational
│
└─ AZ-900: Azure Fundamentals
│
├─ Associate Level
│   ├─ AZ-104: Azure Administrator
│   └─ AZ-204: Azure Developer
│
└─ Expert Level
    ├─ AZ-305: Azure Solutions Architect (requires AZ-104)
    └─ AZ-400: DevOps Engineer (requires AZ-104 or AZ-204)
Specialty (Optional)
├─ AZ-500: Security Technologies
├─ AI-102: AI Engineer
├─ DP-203: Data Engineer
└─ AZ-700: Network Engineer
Resources
Official Documentation
- Azure Documentation: https://docs.microsoft.com/azure
- Azure CLI Reference: https://docs.microsoft.com/cli/azure/
- Azure SDK Documentation: https://azure.github.io/azure-sdk/
Learning Resources
- Microsoft Learn: https://learn.microsoft.com/training/
- Azure Free Account: https://azure.microsoft.com/free/
- Azure Architecture Center: https://docs.microsoft.com/azure/architecture/
- Azure Samples: https://github.com/Azure-Samples
- Azure Friday: https://azure.microsoft.com/resources/videos/azure-friday/
Community
- r/AZURE: Reddit community
- Microsoft Q&A: https://docs.microsoft.com/answers/
- Azure Community Support: https://azure.microsoft.com/support/community/
- Azure User Groups: https://www.meetup.com/pro/azureug
Tools
- Azure CLI: Command-line interface
- Azure PowerShell: PowerShell modules
- Azure SDKs: Python, JavaScript, Java, .NET, Go
- Bicep: Azure-native IaC
- Terraform: Multi-cloud IaC
- Azure Storage Explorer: GUI for storage
- Azure Data Studio: Database management
Pricing
- Azure Pricing Calculator: https://azure.microsoft.com/pricing/calculator/
- Azure Cost Management: https://azure.microsoft.com/services/cost-management/
- Total Cost of Ownership (TCO) Calculator: https://azure.microsoft.com/pricing/tco/
Updated: January 2025
Tools
This section provides an overview of various tools that can enhance your productivity and efficiency in different domains. Each tool is accompanied by a detailed guide on how to use it effectively.
List of Tools
- tmux: A terminal multiplexer that allows you to switch between several programs in one terminal, detach them, and reattach them to a different terminal.
- vim: A highly configurable text editor built to enable efficient text editing.
- cscope: A developer's tool for browsing source code in a terminal environment.
- ctags: A programming tool that generates an index (or tag) file of names found in source and header files.
- mdbook: A command line tool to create books with Markdown.
- sed: A stream editor for filtering and transforming text.
- awk: A programming language designed for text processing and typically used as a data extraction and reporting tool.
- curl: A command-line tool for transferring data with URLs.
- wget: A free utility for non-interactive download of files from the web.
- grep: A command-line utility for searching plain-text data sets for lines that match a regular expression.
- find: A command-line utility that searches for files in a directory hierarchy.
- ffmpeg: A complete, cross-platform solution to record, convert and stream audio and video.
- make: A build automation tool that automatically builds executable programs and libraries from source code.
- Docker: A set of platform-as-a-service products that use OS-level virtualization to deliver software in packages called containers.
- Ansible: An open-source software provisioning, configuration management, and application-deployment tool.
- wpa_supplicant: WiFi client authentication daemon for connecting to wireless networks.
- hostapd: WiFi access point and authentication server for creating wireless access points.
Each tool listed above has its own dedicated page with detailed instructions on how to install, configure, and use it effectively. Click on the tool name to navigate to its respective guide.
tmux
tmux (terminal multiplexer) is a powerful tool that allows you to create, access, and control multiple terminal sessions from a single window. It enables session persistence, split panes, and window management.
Overview
tmux allows you to:
- Run multiple terminal sessions in a single window
- Split your terminal into multiple panes
- Detach and reattach sessions (sessions persist after disconnection)
- Share sessions between users
- Script and automate terminal workflows
Key Concepts:
- Session: A collection of windows, managed independently
- Window: A single screen within a session (like a tab)
- Pane: A split section within a window
- Prefix Key: Default Ctrl+b, used before tmux commands
- Detach: Disconnect from session (keeps running in background)
- Attach: Reconnect to an existing session
Installation
# Ubuntu/Debian
sudo apt update
sudo apt install tmux
# macOS
brew install tmux
# CentOS/RHEL
sudo yum install tmux
# Arch Linux
sudo pacman -S tmux
# Verify installation
tmux -V
Basic Usage
Session Management
# Start new session
tmux
# Start new session with name
tmux new -s mysession
tmux new-session -s mysession
# List sessions
tmux ls
tmux list-sessions
# Attach to session
tmux attach
tmux a
# Attach to specific session
tmux attach -t mysession
tmux a -t mysession
# Detach from session (inside tmux)
# Press: Ctrl+b, then d
# Kill session
tmux kill-session -t mysession
# Kill all sessions
tmux kill-server
# Rename session (inside tmux)
# Press: Ctrl+b, then $
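These session commands can also be scripted, which is how the "script and automate terminal workflows" point from the overview works in practice. A sketch that builds a two-window development layout non-interactively (the session/window names and the commands run in each pane are made up for illustration):

```shell
#!/bin/sh
# Build a tmux layout non-interactively: one session, two windows, a split.
command -v tmux >/dev/null 2>&1 || { echo "tmux not installed"; exit 0; }

SESSION=dev
tmux kill-session -t "$SESSION" 2>/dev/null   # start clean if it already exists

tmux new-session -d -s "$SESSION" -n editor          # detached session, window "editor"
tmux send-keys  -t "$SESSION":editor 'vim' C-m       # launch an editor in it
tmux new-window -t "$SESSION" -n logs                # second window for logs
tmux send-keys  -t "$SESSION":logs 'tail -f /var/log/syslog' C-m
tmux split-window -h -t "$SESSION":editor            # side-by-side split in "editor"
tmux select-window -t "$SESSION":editor

# Attach when ready (omit in non-interactive use):
# tmux attach -t "$SESSION"
```

Because the session is created detached (`-d`), the script can run from a login shell or cron job and you attach to the finished layout later.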
Window Management
# Inside tmux, press Ctrl+b then:
# c - Create new window
# , - Rename current window
# w - List windows
# n - Next window
# p - Previous window
# 0-9 - Switch to window number
# l - Last active window
# & - Kill current window
# f - Find window by name
Pane Management
# Inside tmux, press Ctrl+b then:
# % - Split pane vertically
# " - Split pane horizontally
# Arrow keys - Navigate between panes
# o - Switch to next pane
# ; - Toggle between current and previous pane
# x - Kill current pane
# z - Toggle pane zoom (fullscreen)
# Space - Toggle between layouts
# { - Move pane left
# } - Move pane right
# Ctrl+Arrow - Resize pane
# q - Show pane numbers (then press number to switch)
Configuration
Basic .tmux.conf
# Create configuration file
cat << 'EOF' > ~/.tmux.conf
# Change prefix from Ctrl+b to Ctrl+a
set-option -g prefix C-a
unbind-key C-b
bind-key C-a send-prefix
# Enable mouse support
set -g mouse on
# Start windows and panes at 1, not 0
set -g base-index 1
setw -g pane-base-index 1
# Renumber windows when one is closed
set -g renumber-windows on
# Increase scrollback buffer size
set -g history-limit 10000
# Enable 256 colors
set -g default-terminal "screen-256color"
# Reload config file
bind r source-file ~/.tmux.conf \; display "Config reloaded!"
# Split panes with | and -
bind | split-window -h
bind - split-window -v
unbind '"'
unbind %
# Switch panes using Alt+Arrow without prefix
bind -n M-Left select-pane -L
bind -n M-Right select-pane -R
bind -n M-Up select-pane -U
bind -n M-Down select-pane -D
# Set status bar
set -g status-bg black
set -g status-fg white
set -g status-interval 60
set -g status-left-length 30
set -g status-left '#[fg=green](#S) #(whoami) '
set -g status-right '#[fg=yellow]#(cut -d " " -f 1-3 /proc/loadavg)#[default] #[fg=white]%H:%M#[default]'
EOF
# Reload tmux configuration
tmux source-file ~/.tmux.conf
Advanced Configuration
cat << 'EOF' > ~/.tmux.conf
# ===== Basic Settings =====
set-option -g prefix C-a
unbind-key C-b
bind-key C-a send-prefix
# Enable mouse
set -g mouse on
# Start numbering at 1
set -g base-index 1
setw -g pane-base-index 1
# Renumber windows
set -g renumber-windows on
# History
set -g history-limit 50000
# Terminal settings
set -g default-terminal "screen-256color"
set -ga terminal-overrides ",*256col*:Tc"
# No delay for escape key
set -sg escape-time 0
# Monitor activity
setw -g monitor-activity on
set -g visual-activity off
# ===== Key Bindings =====
# Reload config
bind r source-file ~/.tmux.conf \; display "Reloaded!"
# Split panes
bind | split-window -h -c "#{pane_current_path}"
bind - split-window -v -c "#{pane_current_path}"
bind c new-window -c "#{pane_current_path}"
# Pane navigation
bind h select-pane -L
bind j select-pane -D
bind k select-pane -U
bind l select-pane -R
# Pane resizing
bind -r H resize-pane -L 5
bind -r J resize-pane -D 5
bind -r K resize-pane -U 5
bind -r L resize-pane -R 5
# Window navigation
bind -r C-h select-window -t :-
bind -r C-l select-window -t :+
# Copy mode with vi keys
setw -g mode-keys vi
bind-key -T copy-mode-vi 'v' send -X begin-selection
bind-key -T copy-mode-vi 'y' send -X copy-selection-and-cancel
# ===== Appearance =====
# Status bar
set -g status-position bottom
set -g status-justify left
set -g status-style 'bg=colour234 fg=colour137'
set -g status-left ''
set -g status-right '#[fg=colour233,bg=colour241,bold] %d/%m #[fg=colour233,bg=colour245,bold] %H:%M:%S '
set -g status-right-length 50
set -g status-left-length 20
# Window status
setw -g window-status-current-style 'fg=colour1 bg=colour19 bold'
setw -g window-status-current-format ' #I#[fg=colour249]:#[fg=colour255]#W#[fg=colour249]#F '
setw -g window-status-style 'fg=colour9 bg=colour18'
setw -g window-status-format ' #I#[fg=colour237]:#[fg=colour250]#W#[fg=colour244]#F '
# Pane borders
set -g pane-border-style 'fg=colour238'
set -g pane-active-border-style 'fg=colour51'
# Message text
set -g message-style 'fg=colour232 bg=colour166 bold'
EOF
Key Bindings Reference
Default Prefix: Ctrl+b
Session Commands
Ctrl+b d # Detach from session
Ctrl+b s # List sessions
Ctrl+b $ # Rename session
Ctrl+b ( # Switch to previous session
Ctrl+b ) # Switch to next session
Ctrl+b L # Switch to last session
Window Commands
Ctrl+b c # Create new window
Ctrl+b , # Rename current window
Ctrl+b & # Kill current window
Ctrl+b w # List windows
Ctrl+b n # Next window
Ctrl+b p # Previous window
Ctrl+b 0-9 # Switch to window by number
Ctrl+b l # Switch to last active window
Ctrl+b f # Find window
Ctrl+b . # Move window (prompts for index)
Pane Commands
Ctrl+b % # Split vertically
Ctrl+b " # Split horizontally
Ctrl+b o # Switch to next pane
Ctrl+b ; # Toggle between current and previous pane
Ctrl+b x # Kill current pane
Ctrl+b ! # Break pane into window
Ctrl+b z # Toggle pane zoom
Ctrl+b Space # Toggle between pane layouts
Ctrl+b q # Show pane numbers
Ctrl+b { # Move pane left
Ctrl+b } # Move pane right
Ctrl+b Ctrl+o # Rotate panes
Ctrl+b Arrow # Navigate panes
Copy Mode
Ctrl+b [ # Enter copy mode
Ctrl+b ] # Paste buffer
Space # Start selection (in copy mode)
Enter # Copy selection (in copy mode)
q # Exit copy mode
# With vi mode enabled:
v # Begin selection
y # Copy selection
Other Commands
Ctrl+b ? # List all key bindings
Ctrl+b : # Enter command mode
Ctrl+b t # Show time
Ctrl+b ~ # Show messages
Common Workflows
Development Environment
# Create development session
tmux new -s dev
# Inside tmux:
# Split into 3 panes
Ctrl+b % # Split vertically
Ctrl+b " # Split right pane horizontally
# Now you have:
# - Left pane: Editor (vim/emacs)
# - Top right: Run server
# - Bottom right: Git/commands
# Navigate between panes
Ctrl+b Arrow keys
Remote Server Session
# SSH to server
ssh user@server
# Start tmux session
tmux new -s work
# Do work...
# Connection drops or intentional detach
Ctrl+b d
# Reconnect later
ssh user@server
tmux attach -t work
# Your session is exactly as you left it
Pair Programming
# User 1: Create session
tmux new -s pair
# User 2: Attach to same session (read-only)
tmux attach -t pair -r
# User 2: Attach with full control
tmux attach -t pair
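The commands above assume both logins share one Unix account. Two different accounts can also share a session through an explicitly named socket. The sketch below is illustrative only: the socket path is arbitrary, and world-writable permissions are for demonstration — prefer a shared group with mode 660 in practice. On tmux 3.3+, the session owner must additionally grant access with `server-access`.

```shell
# User 1: create a session on a named socket
tmux -S /tmp/pair.sock new -s pair
# User 1: open the socket to the other account (demo only;
# prefer chgrp + chmod 660 with a shared group)
chmod 777 /tmp/pair.sock
# User 1, tmux 3.3+ only: allow the second account (run inside tmux)
# tmux server-access -a otheruser
# User 2: attach through the same socket
tmux -S /tmp/pair.sock attach -t pair
```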
Multiple Projects
# Create sessions for different projects
tmux new -s project1 -d
tmux new -s project2 -d
tmux new -s project3 -d
# List all sessions
tmux ls
# Attach to specific project
tmux attach -t project1
# Switch between sessions (inside tmux)
Ctrl+b s # Shows session list
Ctrl+b ( # Previous session
Ctrl+b ) # Next session
Advanced Features
Copy and Paste
# Enter copy mode
Ctrl+b [
# Navigate with vi keys (if vi mode enabled)
# Or use arrow keys
# Start selection
Space
# Copy selection
Enter
# Paste
Ctrl+b ]
# View paste buffers
Ctrl+b #
# Choose buffer to paste
Ctrl+b =
Synchronized Panes
# Enable synchronized panes (type in all panes at once)
Ctrl+b :
:setw synchronize-panes on
# Disable
:setw synchronize-panes off
# Toggle with binding (add to .tmux.conf)
bind S setw synchronize-panes
Save and Restore Sessions
# Save the current paste buffer to a file
# (note: save-buffer saves buffer contents, not the window layout)
Ctrl+b :
:save-buffer /tmp/tmux-buffer.txt
# Create script to restore layout
cat << 'EOF' > ~/restore-session.sh
#!/bin/bash
tmux new-session -d -s dev
tmux split-window -h
tmux split-window -v
tmux select-pane -t 0
tmux send-keys 'vim' C-m
tmux select-pane -t 1
tmux send-keys 'npm run dev' C-m
tmux select-pane -t 2
tmux attach -t dev
EOF
chmod +x ~/restore-session.sh
Tmux Plugins (TPM)
# Install Tmux Plugin Manager
git clone https://github.com/tmux-plugins/tpm ~/.tmux/plugins/tpm
# Add to .tmux.conf
cat << 'EOF' >> ~/.tmux.conf
# List of plugins
set -g @plugin 'tmux-plugins/tpm'
set -g @plugin 'tmux-plugins/tmux-sensible'
set -g @plugin 'tmux-plugins/tmux-resurrect'
set -g @plugin 'tmux-plugins/tmux-continuum'
# Initialize TPM (keep at bottom of .tmux.conf)
run '~/.tmux/plugins/tpm/tpm'
EOF
# Reload config
tmux source ~/.tmux.conf
# Install plugins (inside tmux)
Ctrl+b I
Custom Scripts
# Create reusable session layout
cat << 'EOF' > ~/tmux-dev.sh
#!/bin/bash
SESSION="dev"
SESSIONEXISTS=$(tmux list-sessions 2>/dev/null | grep -w "$SESSION")
if [ "$SESSIONEXISTS" = "" ]
then
# Create new session
tmux new-session -d -s $SESSION
# Create windows
tmux rename-window -t $SESSION:0 'Editor'
tmux send-keys -t $SESSION:0 'cd ~/project && vim' C-m
tmux new-window -t $SESSION:1 -n 'Server'
tmux send-keys -t $SESSION:1 'cd ~/project && npm run dev' C-m
tmux new-window -t $SESSION:2 -n 'Git'
tmux send-keys -t $SESSION:2 'cd ~/project && git status' C-m
# Split panes (target the session explicitly so this works from outside tmux)
tmux select-window -t $SESSION:2
tmux split-window -h
tmux send-keys -t $SESSION:2.1 'cd ~/project' C-m
fi
# Attach to session
tmux attach-session -t $SESSION:0
EOF
chmod +x ~/tmux-dev.sh
Command Mode
# Enter command mode
Ctrl+b :
# Common commands
:new-window -n mywindow
:kill-window
:split-window -h
:resize-pane -D 10
:setw synchronize-panes on
:set mouse on
:source-file ~/.tmux.conf
:list-keys
:list-commands
Scripting tmux
Create Complex Layouts
#!/bin/bash
# Create session with specific layout
tmux new-session -d -s complex
# Split into 4 panes
tmux split-window -h -t complex
tmux split-window -v -t complex:0.0
tmux split-window -v -t complex:0.2
# Send commands to each pane
tmux send-keys -t complex:0.0 'htop' C-m
tmux send-keys -t complex:0.1 'tail -f /var/log/syslog' C-m
tmux send-keys -t complex:0.2 'vim' C-m
tmux send-keys -t complex:0.3 'echo "Ready for commands"' C-m
# Attach to session
tmux attach -t complex
Automation Script
#!/bin/bash
# Monitor multiple servers
SERVERS=("server1" "server2" "server3")
SESSION="monitoring"
tmux new-session -d -s $SESSION
for i in "${!SERVERS[@]}"; do
if [ $i -eq 0 ]; then
tmux rename-window -t $SESSION:0 "${SERVERS[$i]}"
else
tmux new-window -t $SESSION:$i -n "${SERVERS[$i]}"
fi
tmux send-keys -t $SESSION:$i "ssh ${SERVERS[$i]}" C-m
done
tmux select-window -t $SESSION:0
tmux attach -t $SESSION
Best Practices
Recommended .tmux.conf Settings
# Essential settings
set -g mouse on # Enable mouse
set -g history-limit 50000 # Large scrollback
set -sg escape-time 0 # No escape delay
set -g base-index 1 # Start windows at 1
setw -g pane-base-index 1 # Start panes at 1
set -g renumber-windows on # Renumber windows
# Visual settings
set -g default-terminal "screen-256color"
set -g status-position bottom
setw -g monitor-activity on
# Key bindings
bind r source-file ~/.tmux.conf \; display "Reloaded!"
bind | split-window -h -c "#{pane_current_path}"
bind - split-window -v -c "#{pane_current_path}"
setw -g mode-keys vi
Workflow Tips
- Use named sessions for different projects
- Create restore scripts for complex layouts
- Enable mouse support for easier navigation
- Use vi key bindings in copy mode
- Set up custom key bindings for frequent actions
- Use tmux with SSH for persistent remote sessions
- Share sessions for collaboration
- Create aliases for common commands
Useful Aliases
# Add to ~/.bashrc or ~/.zshrc
alias tm='tmux'
alias tma='tmux attach -t'
alias tms='tmux new-session -s'
alias tml='tmux list-sessions'
alias tmk='tmux kill-session -t'
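One more worth adding: `new-session -A` (tmux 1.8+) attaches when the named session already exists and creates it otherwise, so a single alias covers both cases. The alias name `tmn` is just a suggestion:

```shell
# Add to ~/.bashrc or ~/.zshrc
# tmn mysession -> attach to "mysession", creating it first if needed
alias tmn='tmux new-session -A -s'
```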
Troubleshooting
Common Issues
# Prefix key not working
# Check if prefix is correct in .tmux.conf
tmux show-options -g | grep prefix
# Colors not displaying correctly
set -g default-terminal "screen-256color"
# Mouse not working
set -g mouse on
# Sessions not persisting
# Make sure you detach (Ctrl+b d) instead of exiting
# Can't attach to session
# Check if session exists
tmux ls
# Configuration not loading
# Reload config
tmux source-file ~/.tmux.conf
# Reset tmux to defaults
tmux kill-server
rm ~/.tmux.conf
Debug Mode
# Start tmux with verbose logging (writes tmux-*.log files to the current directory)
tmux -v
# Show current settings
tmux show-options -g
tmux show-window-options -g
# Check key bindings
tmux list-keys
# Show messages
Ctrl+b ~
Integration with Tools
Vim Integration
# Add to .vimrc for seamless navigation
if exists('$TMUX')
" Use same keybindings for vim and tmux
let g:tmux_navigator_no_mappings = 1
endif
Shell Integration
# Auto-attach or create session
if command -v tmux &> /dev/null && [ -z "$TMUX" ]; then
tmux attach -t default || tmux new -s default
fi
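A stricter variant replaces the login shell with tmux via `exec`, so detaching or exiting tmux also ends the SSH connection. This sketch is intended for the remote machine's shell rc file; the session name `default` is a placeholder:

```shell
# Add to ~/.bashrc or ~/.zshrc on the remote host
if command -v tmux >/dev/null 2>&1 && [ -z "$TMUX" ]; then
    exec tmux new -A -s default   # -A: attach if "default" already exists
fi
```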
Quick Reference
| Command | Description |
|---|---|
| tmux | Start new session |
| tmux new -s name | Start named session |
| tmux ls | List sessions |
| tmux attach -t name | Attach to session |
| Ctrl+b d | Detach from session |
| Ctrl+b c | Create window |
| Ctrl+b , | Rename window |
| Ctrl+b % | Split vertically |
| Ctrl+b " | Split horizontally |
| Ctrl+b Arrow | Navigate panes |
| Ctrl+b z | Zoom pane |
| Ctrl+b [ | Copy mode |
| Ctrl+b ? | List keybindings |
tmux is an essential tool for managing terminal workflows, especially valuable for remote server management, development environments, and maintaining persistent sessions.
Vim
Vim is a powerful text editor that is widely used in the Unix and Linux communities. It is a modal editor: the same keys mean different things depending on the active mode. The two modes you will use most are normal mode, in which keystrokes are commands for moving around and manipulating text, and insert mode, in which you type text as you would in any other editor. (Vim has further modes, such as visual mode and command-line mode, covered below.)
Commonly Used Vim Key Combinations
- Switching Modes:
  - i: Enter insert mode
  - Esc: Return to normal mode
- Navigation:
  - h: Move left
  - j: Move down
  - k: Move up
  - l: Move right
  - w: Jump to the start of the next word
  - b: Jump to the start of the previous word
  - 0: Jump to the beginning of the line
  - $: Jump to the end of the line
  - gg: Go to the top of the file
  - G: Go to the bottom of the file
- Editing:
  - x: Delete the character under the cursor
  - dd: Delete the current line
  - yy: Yank (copy) the current line
  - p: Paste the yanked or deleted text after the cursor
  - u: Undo the last change
  - Ctrl + r: Redo the last undone change
  - r: Replace the character under the cursor
- Searching:
  - /pattern: Search for a pattern
  - n: Repeat the search in the same direction
  - N: Repeat the search in the opposite direction
- Saving and Exiting:
  - :w: Save the file
  - :q: Quit Vim
  - :wq: Save the file and quit Vim
  - :q!: Quit Vim without saving
- Visual Mode:
  - v: Enter visual mode to select text
  - V: Enter visual line mode to select whole lines
  - Ctrl + v: Enter visual block mode to select a block of text
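These commands also compose: normal-mode editing follows a rough grammar of [count][operator][motion], so the keys above combine into far more commands than can be listed. A few illustrative combinations:

```
3j   " move down three lines (count + motion)
dw   " delete to the start of the next word (operator + motion)
d$   " delete to the end of the line
y0   " yank back to the beginning of the line
3dd  " delete three whole lines
```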
These key combinations cover a variety of common tasks in Vim, making it a versatile and efficient text editor.
cscope
cscope is a developer's tool for browsing source code in a terminal environment. It's particularly useful for navigating large C codebases, allowing you to search for symbols, function calls, and definitions interactively.
Overview
cscope builds a symbol database from source files and provides a text-based interface for code navigation. While originally designed for C, it also supports C++ and Java.
Key Features:
- Find function definitions and calls
- Search for symbols, assignments, and regular expressions
- Navigate to files containing specific text
- Interactive text-based interface
- Integration with text editors (Vim, Emacs)
- Cross-reference capabilities
Installation
# Ubuntu/Debian
sudo apt update
sudo apt install cscope
# macOS
brew install cscope
# CentOS/RHEL
sudo yum install cscope
# Arch Linux
sudo pacman -S cscope
# Verify installation
cscope -V
Basic Usage
Building Database
# Build database from current directory
cscope -b
# Build database recursively
cscope -bR
# Build database from specific files
cscope -b file1.c file2.c file3.c
# Build from file list
find . -name "*.c" -o -name "*.h" > cscope.files
cscope -b
# Build without launching interface
cscope -b -q # -q for faster database
# Update existing database
cscope -u -b
Interactive Mode
# Launch cscope
cscope
# Launch with specific database
cscope -d # Use existing database (don't rebuild)
# Launch recursively
cscope -R
# Launch in line-oriented mode
cscope -l
Interactive Commands
# In cscope interface:
Tab - Toggle between input field and results
Ctrl+D - Exit cscope
Ctrl+P - Navigate to previous result
Ctrl+N - Navigate to next result
Enter - View selected result
Space - Display next page of results
1-9 - Edit file at result number
# Search types:
0 - Find this C symbol
1 - Find this global definition
2 - Find functions called by this function
3 - Find functions calling this function
4 - Find this text string
5 - Change this text string (grep pattern)
6 - Find this egrep pattern
7 - Find this file
8 - Find files #including this file
9 - Find assignments to this symbol
Command Line Searches
# Find symbol
cscope -L0 symbol_name
# Find global definition
cscope -L1 function_name
# Find functions called by function
cscope -L2 function_name
# Find functions calling function
cscope -L3 function_name
# Find text string
cscope -L4 "error message"
# Find egrep pattern
cscope -L6 "struct.*{$"
# Find file
cscope -L7 filename.c
# Find files including header
cscope -L8 header.h
# Output to file
cscope -L0 main > results.txt
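Line-oriented output is easy to post-process: each result line holds four whitespace-separated fields — file, function (or scope), line number, and the matching source text. A sketch that reformats fabricated `-L3` results (the paths and function names are invented) into file:line form:

```shell
# Pipe hypothetical `cscope -dL3 init_config` output through awk
printf '%s\n' \
  'src/main.c main 42 init_config();' \
  'src/util.c helper 17 init_config();' |
awk '{ print $1 ":" $3 " (caller: " $2 ")" }'
# prints:
# src/main.c:42 (caller: main)
# src/util.c:17 (caller: helper)
```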
Vim Integration
Basic Setup
" Add to ~/.vimrc
if has("cscope")
set csprg=/usr/bin/cscope
set csto=0
set cst
set nocsverb
" Add cscope database if it exists
if filereadable("cscope.out")
cs add cscope.out
endif
set csverb
endif
Advanced Vim Configuration
" ~/.vimrc
if has("cscope")
set csprg=/usr/bin/cscope
set csto=0
set cst
set csverb
" Load database
if filereadable("cscope.out")
cs add cscope.out
elseif $CSCOPE_DB != ""
cs add $CSCOPE_DB
endif
" Key mappings
nmap <C-\>s :cs find s <C-R>=expand("<cword>")<CR><CR>
nmap <C-\>g :cs find g <C-R>=expand("<cword>")<CR><CR>
nmap <C-\>c :cs find c <C-R>=expand("<cword>")<CR><CR>
nmap <C-\>t :cs find t <C-R>=expand("<cword>")<CR><CR>
nmap <C-\>e :cs find e <C-R>=expand("<cword>")<CR><CR>
nmap <C-\>f :cs find f <C-R>=expand("<cfile>")<CR><CR>
nmap <C-\>i :cs find i ^<C-R>=expand("<cfile>")<CR>$<CR>
nmap <C-\>d :cs find d <C-R>=expand("<cword>")<CR><CR>
" Horizontal split
nmap <C-@>s :scs find s <C-R>=expand("<cword>")<CR><CR>
nmap <C-@>g :scs find g <C-R>=expand("<cword>")<CR><CR>
nmap <C-@>c :scs find c <C-R>=expand("<cword>")<CR><CR>
endif
" Auto-rebuild cscope database
function! UpdateCscope()
silent !cscope -Rb
cs reset
endfunction
command! Cscope call UpdateCscope()
Vim Commands
" In Vim:
:cs find s symbol " Find symbol
:cs find g definition " Find global definition
:cs find c function " Find calls to function
:cs find t text " Find text
:cs find e pattern " Find egrep pattern
:cs find f file " Find file
:cs find i file " Find files #including file
:cs find d symbol " Find functions called by symbol
" Show cscope connections
:cs show
" Reset cscope connections
:cs reset
" Kill cscope connection
:cs kill 0
Advanced Usage
Custom File Lists
# C/C++ project
find . \( -name "*.c" -o -name "*.h" -o -name "*.cpp" -o -name "*.hpp" \) > cscope.files
cscope -b -q
# Exclude directories
find . -path "./build" -prune -o -name "*.c" -print > cscope.files
# Include specific directories only
find src include -name "*.[ch]" > cscope.files
cscope -b -q
Kernel-style Setup
# Linux kernel style
cat << 'EOF' > build_cscope.sh
#!/bin/bash
LNX=/path/to/linux/source
find $LNX \
-path "$LNX/arch/*" ! -path "$LNX/arch/x86*" -prune -o \
-path "$LNX/tmp*" -prune -o \
-path "$LNX/Documentation*" -prune -o \
-path "$LNX/scripts*" -prune -o \
-type f \( -name '*.[chxsS]' -o -name 'Makefile' \) \
-print > cscope.files
cscope -b -q -k
EOF
chmod +x build_cscope.sh
./build_cscope.sh
Multiple Projects
# Project 1
cd /project1
cscope -b -q
export CSCOPE_DB=/project1/cscope.out
# Project 2 (separate database)
cd /project2
cscope -b -q -f cscope_proj2.out
# Use in Vim
:cs add /project1/cscope.out /project1
:cs add /project2/cscope_proj2.out /project2
Scripting with cscope
Automated Searches
#!/bin/bash
# find_function_calls.sh
FUNC=$1
if [ -z "$FUNC" ]; then
echo "Usage: $0 <function_name>"
exit 1
fi
echo "Functions calling $FUNC:"
cscope -dL3 $FUNC
echo ""
echo "Functions called by $FUNC:"
cscope -dL2 $FUNC
Generate Call Graph
#!/bin/bash
# Generate simple call graph
FUNC=$1
function recurse_calls() {
local func=$1
local indent=$2
echo "${indent}${func}"
# Find functions called by this function
cscope -dL2 "$func" | while read line; do
called=$(echo $line | awk '{print $2}')
if [ ! -z "$called" ]; then
recurse_calls "$called" "${indent} "
fi
done
}
recurse_calls "$FUNC" ""
Find Unused Functions
#!/bin/bash
# find_unused.sh
# Get all function definitions
cscope -dL1 "" | awk '{print $2}' | sort -u > /tmp/all_funcs.txt
# For each function, check if it's called
while read func; do
if [ "$func" != "main" ]; then
calls=$(cscope -dL3 "$func" | wc -l)
if [ $calls -eq 0 ]; then
echo "Unused: $func"
fi
fi
done < /tmp/all_funcs.txt
rm /tmp/all_funcs.txt
Makefile Integration
# Add to Makefile
.PHONY: cscope
cscope:
@find . -name "*.[ch]" > cscope.files
@cscope -b -q
.PHONY: cscope-clean
cscope-clean:
@rm -f cscope.* cscope.files
.PHONY: cscope-update
cscope-update: cscope-clean cscope
Configuration File
# ~/.cscoperc or project .cscoperc
# (cscope automatically loads this)
# Custom options (limited support)
# Most configuration done via command line
Emacs Integration
;; Add to ~/.emacs or ~/.emacs.d/init.el
(require 'xcscope)
(cscope-setup)
;; Key bindings
(define-key global-map [(control f3)] 'cscope-set-initial-directory)
(define-key global-map [(control f4)] 'cscope-find-this-symbol)
(define-key global-map [(control f5)] 'cscope-find-global-definition)
(define-key global-map [(control f6)] 'cscope-find-functions-calling-this-function)
(define-key global-map [(control f7)] 'cscope-find-called-functions)
(define-key global-map [(control f8)] 'cscope-find-this-text-string)
(define-key global-map [(control f9)] 'cscope-find-this-file)
(define-key global-map [(control f10)] 'cscope-find-files-including-file)
;; Auto-update database
(setq cscope-do-not-update-database nil)
Best Practices
Large Projects
# Build inverted index for faster searches
cscope -b -q
# Store the database uncompressed (-c uses plain ASCII; the default is compressed)
cscope -b -c
# Incremental updates
cscope -u -b -q
# Index only relevant files
find . -name "*.[ch]" \
! -path "*/test/*" \
! -path "*/build/*" \
> cscope.files
cscope -b -q
Project Setup Script
#!/bin/bash
# setup_cscope.sh
PROJECT_ROOT=$(git rev-parse --show-toplevel 2>/dev/null || pwd)
cd "$PROJECT_ROOT"
echo "Building cscope database for: $PROJECT_ROOT"
# Find relevant source files
find . \( -name "*.c" -o -name "*.h" -o -name "*.cpp" -o -name "*.hpp" -o -name "*.cc" \) \
! -path "*/build/*" \
! -path "*/\.git/*" \
! -path "*/node_modules/*" \
> cscope.files
# Build database with inverted index
cscope -b -q -k
echo "Database built: cscope.out"
echo ""
echo "Usage:"
echo " cscope -d # Launch interactive mode"
echo " vim <file> # Use with Vim (if configured)"
echo " cscope -L0 symbol # Command-line search"
Automatic Rebuilds
# Add to project root
# .git/hooks/post-commit
#!/bin/bash
echo "Rebuilding cscope database..."
cscope -b -q -k
echo "Done"
Common Patterns
Search All Files
# Find all occurrences of a string
cscope -L4 "TODO"
# Find all error messages
cscope -L4 "error:"
# Find struct definitions
cscope -L6 "^struct"
# Find all malloc calls
cscope -L0 malloc
Code Review
# Find all functions modified in recent commit
git diff --name-only HEAD~1 | grep '\.[ch]$' | while read file; do
echo "=== $file ==="
# Get function names from file
ctags -x --c-kinds=f "$file" | awk '{print $1}'
done
Troubleshooting
# Database not found
cscope -b -R # Rebuild recursively
# Incomplete results
rm cscope.out*
cscope -b -q # Rebuild with index
# Vim integration not working
:cs show # Check connections
:cs reset # Reset connections
:cs add cscope.out
# Permission denied
chmod 644 cscope.out*
# Slow searches
cscope -b -q # Build with inverted index
# Wrong directory
export CSCOPE_DB=/path/to/cscope.out
Quick Reference
| Command | Description |
|---|---|
| cscope -b | Build database |
| cscope -R | Recursive search |
| cscope -d | Use existing database |
| cscope -u | Update database |
| cscope -q | Build inverted index |
| cscope -L0 | Find symbol |
| cscope -L1 | Find definition |
| cscope -L3 | Find callers |
| :cs find s | Vim: Find symbol |
| :cs find g | Vim: Find definition |
cscope is an essential tool for navigating large C codebases, providing fast symbol lookups and cross-references that make code exploration and maintenance significantly easier.
ctags
ctags is a tool that generates an index (or "tag") file of names found in source and header files, enabling efficient code navigation in text editors. It supports numerous programming languages and integrates seamlessly with Vim, Emacs, and other editors.
Overview
ctags creates a database of language objects (functions, classes, variables, etc.) found in source files, allowing editors to quickly jump to definitions. Modern implementations include Exuberant Ctags and Universal Ctags.
Key Features:
- Multi-language support (C, C++, Python, Java, JavaScript, etc.)
- Editor integration (Vim, Emacs, Sublime, VS Code)
- Recursive directory scanning
- Custom tag patterns
- Symbol cross-referencing
- Incremental updates
Installation
# Ubuntu/Debian - Universal Ctags (recommended)
sudo apt update
sudo apt install universal-ctags
# Or Exuberant Ctags (older)
sudo apt install exuberant-ctags
# macOS - Universal Ctags
brew install universal-ctags
# CentOS/RHEL
sudo yum install ctags
# Arch Linux
sudo pacman -S ctags
# From source (Universal Ctags)
git clone https://github.com/universal-ctags/ctags.git
cd ctags
./autogen.sh
./configure
make
sudo make install
# Verify installation
ctags --version
Basic Usage
Generating Tags
# Generate tags for current directory
ctags *
# Recursive tag generation
ctags -R
# Specific files
ctags file1.c file2.c file3.h
# Multiple languages
ctags -R src/ include/
# Generate tags for specific language
ctags -R --languages=C,C++
# Exclude languages
ctags -R --languages=-JavaScript,-HTML
# Follow symbolic links
ctags -R --links=yes
Tag File Options
# Specify output file
ctags -o mytags -R
# Append to existing tags
ctags -a -R new_directory/
# Create tag file with extra information
ctags -R --fields=+iaS --extras=+q
# Sort tags file
ctags -R --sort=yes
# Case-insensitive sorting
ctags -R --sort=foldcase
Vim Integration
Basic Configuration
" Add to ~/.vimrc
set tags=./tags,tags;$HOME
" Search for tags file in current directory and up to $HOME
set tags=./tags;/
Vim Commands
" Jump to definition
Ctrl+] " Jump to tag under cursor
g Ctrl+] " Show list if multiple matches
" Return from jump
Ctrl+T " Jump back (pop tag stack)
Ctrl+O " Jump to previous location
" Navigation
:tag function " Jump to tag
:ts pattern " List matching tags
:tn " Next matching tag
:tp " Previous matching tag
" Tag stack
:tags " Show tag stack
:pop " Pop from tag stack
" Split window navigation
Ctrl+W ] " Split window and jump to tag
Ctrl+W g ] " Split and list matches
Advanced Vim Configuration
" ~/.vimrc
" Set tags file locations
set tags=./tags,tags;$HOME
" Enable tag stack
set tagstack
" Show tag preview in popup
set completeopt=menuone,preview
" Custom key mappings
" Always show the match list when a tag has multiple definitions
" (a trailing comment on the nnoremap line would become part of the mapping)
nnoremap <C-]> g<C-]>
nnoremap <leader>t :tag<Space>
nnoremap <leader>] :tselect<CR>
nnoremap <leader>[ :pop<CR>
" Split navigation
nnoremap <C-\> :tab split<CR>:exec("tag ".expand("<cword>"))<CR>
nnoremap <A-]> :vsp <CR>:exec("tag ".expand("<cword>"))<CR>
" Auto-regenerate tags
autocmd BufWritePost *.c,*.cpp,*.h,*.py silent! !ctags -R &
Vim with Tagbar Plugin
" Install with vim-plug
Plug 'majutsushi/tagbar'
" Configuration
nmap <F8> :TagbarToggle<CR>
let g:tagbar_width = 30
let g:tagbar_autofocus = 1
let g:tagbar_sort = 0
" Custom language configuration
let g:tagbar_type_go = {
\ 'ctagstype' : 'go',
\ 'kinds' : [
\ 'p:package',
\ 'i:imports',
\ 'c:constants',
\ 'v:variables',
\ 't:types',
\ 'n:interfaces',
\ 'w:fields',
\ 'e:embedded',
\ 'm:methods',
\ 'r:constructor',
\ 'f:functions'
\ ],
\ 'sro' : '.',
\ 'kind2scope' : {
\ 't' : 'ctype',
\ 'n' : 'ntype'
\ },
\ 'scope2kind' : {
\ 'ctype' : 't',
\ 'ntype' : 'n'
\ },
\ }
Language-Specific Features
C/C++
# C/C++ with all features
ctags -R \
--c-kinds=+p \
--c++-kinds=+p \
--fields=+iaS \
--extras=+q
# Include system headers
ctags -R --c-kinds=+px --fields=+iaS --extras=+q \
/usr/include \
/usr/local/include \
.
# Kernel-style projects
ctags -R \
--exclude=.git \
--exclude=build \
--exclude=Documentation \
--languages=C \
--langmap=c:.c.h \
--c-kinds=+px \
--fields=+iaS \
--extras=+q
Python
# Python projects
ctags -R \
--languages=Python \
--python-kinds=-i \
--fields=+l
# Include virtualenv
ctags -R \
--languages=Python \
--fields=+l \
. \
venv/lib/python*/site-packages/
JavaScript/TypeScript
# JavaScript
ctags -R \
--languages=JavaScript \
--exclude=node_modules \
--exclude=dist \
--exclude=build
# TypeScript
ctags -R \
--languages=TypeScript \
--exclude=node_modules \
--exclude=*.min.js
Java
# Java projects
ctags -R \
--languages=Java \
--exclude=.git \
--exclude=target \
--exclude=*.class
# Include JAR dependencies (if unpacked)
ctags -R src/ lib/
Advanced Usage
Custom Configuration
# ~/.ctags.d/local.ctags (Universal Ctags)
--recurse=yes
--tag-relative=yes
--exclude=.git
--exclude=.svn
--exclude=.hg
--exclude=node_modules
--exclude=bower_components
--exclude=*.min.js
--exclude=*.swp
--exclude=*.bak
--exclude=*.pyc
--exclude=*.class
--exclude=target
--exclude=build
--exclude=dist
# Language-specific
--langdef=markdown
--langmap=markdown:.md.markdown.mdown.mkd.mkdn
--regex-markdown=/^#{1}[ \t]+(.+)/. \1/h,heading1/
--regex-markdown=/^#{2}[ \t]+(.+)/.. \1/h,heading2/
--regex-markdown=/^#{3}[ \t]+(.+)/... \1/h,heading3/
Project-Specific Tags
# .git/hooks/post-commit
#!/bin/bash
ctags -R &
# Make executable
chmod +x .git/hooks/post-commit
# Or use Makefile
.PHONY: tags
tags:
ctags -R --fields=+iaS --extras=+q
.PHONY: tags-clean
tags-clean:
rm -f tags
Filtering and Exclusions
# Exclude directories
ctags -R --exclude=build --exclude=.git --exclude=node_modules
# Exclude files by pattern
ctags -R --exclude=*.min.js --exclude=*.test.js
# Include only specific directories
ctags -R src/ include/
# Custom exclusions file
echo "build/" > .ctagsignore
echo "*.min.js" >> .ctagsignore
ctags -R --exclude=@.ctagsignore
Scripting with ctags
Automated Tag Generation
#!/bin/bash
# update_tags.sh
PROJECT_ROOT=$(git rev-parse --show-toplevel 2>/dev/null || pwd)
cd "$PROJECT_ROOT"
echo "Generating tags for: $PROJECT_ROOT"
ctags -R \
--fields=+iaS \
--extras=+q \
--exclude=.git \
--exclude=build \
--exclude=node_modules \
--exclude=*.min.js
echo "Tags file generated: $PROJECT_ROOT/tags"
Multi-Project Tags
#!/bin/bash
# generate_all_tags.sh
PROJECTS=(
"$HOME/projects/project1"
"$HOME/projects/project2"
"$HOME/projects/lib/common"
)
for project in "${PROJECTS[@]}"; do
if [ -d "$project" ]; then
echo "Generating tags for $project"
(cd "$project" && ctags -R)
fi
done
# Merge tags files
cat ~/projects/*/tags | sort -u > ~/projects/all_tags
Find Symbol Across Projects
#!/bin/bash
# find_symbol.sh
SYMBOL=$1
if [ -z "$SYMBOL" ]; then
echo "Usage: $0 <symbol>"
exit 1
fi
# Search in tags file
echo "Searching for: $SYMBOL"
echo ""
grep "^$SYMBOL" tags | while IFS=$'\t' read tag file pattern rest; do
echo "File: $file"
echo "Pattern: $pattern"
echo "---"
done
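The script works because the tags file format is simple: tab-separated tag name, file, and an ex search command, with optional extension fields (such as the kind letter) after the `;"` marker. A sketch over two fabricated entries (the symbols and paths are invented):

```shell
# Create a two-entry sample tags file and list tag -> file pairs
printf 'main\tsrc/main.c\t/^int main(void)$/;"\tf\n'      > sample.tags
printf 'helper\tsrc/util.c\t/^static void helper/;"\tf\n' >> sample.tags
awk -F'\t' '{ print $1 " -> " $2 " (kind " $4 ")" }' sample.tags
# prints:
# main -> src/main.c (kind f)
# helper -> src/util.c (kind f)
rm sample.tags
```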
Integration with Other Editors
Emacs
;; Add to ~/.emacs or ~/.emacs.d/init.el
;; Enable etags (similar to ctags)
(setq tags-table-list '("./TAGS" "../TAGS" "../../TAGS"))
;; Key bindings
(global-set-key (kbd "M-.") 'find-tag)
(global-set-key (kbd "M-*") 'pop-tag-mark)
(global-set-key (kbd "M-,") 'tags-loop-continue)
;; Generate tags for project
(defun my-generate-tags ()
(interactive)
(shell-command "ctags -e -R ."))
(global-set-key (kbd "C-c g") 'my-generate-tags)
VS Code
// settings.json
{
"ctagsFile": "tags",
"ctagsPath": "/usr/bin/ctags"
}
// Install extension
// ext install jaydenlin.ctags-support
Sublime Text
// Settings - User
{
"tags_path": "tags",
"ctags_command": "/usr/bin/ctags -R --fields=+iaS --extras=+q"
}
// Install CTags package via Package Control
Common Patterns
Monorepo Tag Management
#!/bin/bash
# monorepo_tags.sh
# Root tags
ctags -R --fields=+iaS --extras=+q -o tags.root .
# Per-service tags
for service in services/*; do
if [ -d "$service" ]; then
(cd "$service" && ctags -R -o tags .)
fi
done
# Merge all tags into one file (write to a temp name first, so the
# redirection doesn't clobber a tags file that find is about to read)
find . -name "tags" -exec cat {} \; | sort -u > tags.merged
mv tags.merged tags
Language-Specific Tag Files
#!/bin/bash
# Generate separate tags for each language
# C/C++ tags
ctags -R -o tags.c --languages=C,C++ .
# Python tags
ctags -R -o tags.py --languages=Python .
# JavaScript tags
ctags -R -o tags.js --languages=JavaScript --exclude=node_modules .
# Merge all
cat tags.* | sort -u > tags
Incremental Updates
#!/bin/bash
# update_changed.sh
# Get changed files since last tag generation
CHANGED=$(find . -type f -newer tags \( -name "*.c" -o -name "*.h" \))
if [ ! -z "$CHANGED" ]; then
echo "Updating tags for changed files"
# Generate tags for changed files
ctags -a $CHANGED
# Sort tags file
sort -u tags -o tags
fi
Best Practices
Recommended Configuration
# ~/.ctags or ~/.ctags.d/default.ctags (Universal Ctags)
# Recurse by default
--recurse=yes
# Tag relative paths
--tag-relative=yes
# Additional fields
--fields=+iaS
--extras=+q
# Common exclusions
--exclude=.git
--exclude=.svn
--exclude=node_modules
--exclude=bower_components
--exclude=*.min.js
--exclude=*.min.css
--exclude=*.map
--exclude=build
--exclude=dist
--exclude=target
--exclude=*.pyc
--exclude=*.class
--exclude=.DS_Store
# Sort tags
--sort=yes
# Language-specific
--languages=all
--c-kinds=+px
--c++-kinds=+px
--python-kinds=-i
Git Integration
# .gitignore
tags
tags.lock
tags.temp
TAGS
# .git/hooks/post-checkout
#!/bin/bash
ctags -R &
# .git/hooks/post-merge
#!/bin/bash
ctags -R &
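The two hooks above can be written by a small installer script. This is a sketch (the script name is hypothetical), demonstrated against a throwaway directory rather than a real repository:

```shell
#!/bin/bash
# install_ctags_hooks.sh - regenerate tags after checkouts and merges
cd "$(mktemp -d)" && mkdir -p .git/hooks   # stand-in repo for this sketch
for hook in post-checkout post-merge; do
    cat > ".git/hooks/$hook" <<'EOF'
#!/bin/bash
ctags -R &
EOF
    chmod +x ".git/hooks/$hook"   # hooks must be executable to run
done
ls -l .git/hooks
```

In a real repository you would run this from the repository root, or point it at `$(git rev-parse --git-dir)/hooks`.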
Performance Tips
# Use parallel processing for large projects (concurrent appends can
# interleave entries; de-duplicate with sort -u afterwards)
find . \( -name "*.c" -o -name "*.h" \) | xargs -P 4 -n 50 ctags -a
sort -u tags -o tags
# Generate tags in background
ctags -R &
# Use faster sorting
ctags -R --sort=no
LC_ALL=C sort tags -o tags
# Exclude large dependency directories
ctags -R --exclude=vendor --exclude=node_modules
Troubleshooting
# Tags file not found in Vim
:set tags? # Check tags path
:set tags=./tags;/ # Set tags path
# Duplicate entries
sort -u tags -o tags
# Wrong language detected
ctags --list-languages # Show supported languages
ctags --list-maps # Show file extensions
ctags -R --languages=C,C++ # Force specific languages
# Performance issues
ctags -R --exclude=node_modules --exclude=vendor
# Tags not updating
rm tags
ctags -R
# Vim not jumping to correct location
# Regenerate with line numbers
ctags -R --fields=+n
# Check tag format
head -n 20 tags
Quick Reference
| Command | Description |
|---|---|
| ctags -R | Generate tags recursively |
| ctags -a | Append to tags |
| ctags --list-languages | Show supported languages |
| Ctrl+] | Vim: Jump to tag |
| Ctrl+T | Vim: Return from tag |
| :ts | Vim: List tags |
| :tag name | Vim: Jump to tag |
| --exclude=DIR | Exclude directory |
| --languages=LANG | Specific languages |
| --fields=+iaS | Extra tag fields |
ctags is an essential tool for code navigation, enabling developers to efficiently explore and understand large codebases by providing instant access to symbol definitions and references.
mdBook
mdBook is a command-line tool for creating books from Markdown files, similar to Gitbook but implemented in Rust. It's fast, simple, and ideal for technical documentation, tutorials, and books.
Overview
mdBook takes Markdown files and generates a static website with built-in search, syntax highlighting, and theme support. It's the tool used to create the official Rust programming language book.
Key Features:
- Fast static site generation
- Automatic table of contents
- Built-in search functionality
- Syntax highlighting for code
- Light and dark themes
- Live preview with hot reloading
- Markdown extensions
- Customizable with preprocessors
Installation
# Using Cargo (Rust package manager)
cargo install mdbook
# Ubuntu/Debian (from binary)
wget https://github.com/rust-lang/mdBook/releases/download/v0.4.36/mdbook-v0.4.36-x86_64-unknown-linux-gnu.tar.gz
tar xzf mdbook-v0.4.36-x86_64-unknown-linux-gnu.tar.gz
sudo mv mdbook /usr/local/bin/
# macOS
brew install mdbook
# From source
git clone https://github.com/rust-lang/mdBook.git
cd mdBook
cargo build --release
sudo cp target/release/mdbook /usr/local/bin/
# Verify installation
mdbook --version
Quick Start
Create a New Book
# Create new book
mdbook init mybook
# Project structure created:
# mybook/
# ├── book.toml # Configuration file
# └── src/
# ├── SUMMARY.md # Table of contents
# └── chapter_1.md
# Enter directory
cd mybook
# Build the book
mdbook build
# Serve with live preview
mdbook serve
# Open in browser
open http://localhost:3000
Project Structure
mybook/
├── book.toml # Configuration
├── src/
│ ├── SUMMARY.md # Table of contents (required)
│ ├── chapter_1.md # Chapter files
│ ├── chapter_2.md
│ ├── images/ # Images directory
│ │ └── diagram.png
│ └── sub_chapter/
│ └── section.md
└── book/ # Generated output (git ignore)
├── index.html
├── chapter_1.html
└── ...
Configuration
Basic book.toml
[book]
title = "My Amazing Book"
authors = ["John Doe"]
language = "en"
multilingual = false
src = "src"
[build]
build-dir = "book"
create-missing = true
[output.html]
default-theme = "light"
preferred-dark-theme = "navy"
git-repository-url = "https://github.com/user/repo"
git-repository-icon = "fa-github"
Advanced Configuration
[book]
title = "Advanced Guide"
authors = ["Jane Smith", "John Doe"]
description = "A comprehensive guide"
language = "en"
multilingual = false
src = "src"
[build]
build-dir = "book"
create-missing = true
[preprocessor.links]
[output.html]
# Theme
default-theme = "rust"
preferred-dark-theme = "navy"
curly-quotes = true
# Repository
git-repository-url = "https://github.com/user/repo"
git-repository-icon = "fa-github"
# Navigation
additional-css = ["custom.css"]
additional-js = ["custom.js"]
# Code
no-section-label = false
# Search
[output.html.search]
enable = true
limit-results = 30
teaser-word-count = 30
use-boolean-and = true
boost-title = 2
boost-hierarchy = 1
boost-paragraph = 1
expand = true
heading-split-level = 3
# Print
[output.html.print]
enable = true
# Playground (for Rust code)
[output.html.playground]
editable = true
copyable = true
copy-js = true
line-numbers = false
runnable = true
SUMMARY.md Format
Basic Structure
# Summary
[Introduction](./introduction.md)
# User Guide
- [Getting Started](./guide/getting-started.md)
- [Installation](./guide/installation.md)
- [Linux](./guide/installation/linux.md)
- [macOS](./guide/installation/macos.md)
- [Windows](./guide/installation/windows.md)
- [Configuration](./guide/configuration.md)
# Reference
- [API Reference](./reference/api.md)
- [CLI Commands](./reference/cli.md)
# Appendix
- [Glossary](./appendix/glossary.md)
- [Contributors](./appendix/contributors.md)
Advanced Features
# Summary
[Preface](./preface.md)
---
# Part I: Basics
- [Chapter 1](./chapter-1.md)
- [Chapter 2](./chapter-2.md)
---
# Part II: Advanced
- [Chapter 3](./chapter-3.md)
- [Section 3.1](./chapter-3/section-1.md)
- [Section 3.2](./chapter-3/section-2.md)
---
[Conclusion](./conclusion.md)
[Appendix](./appendix.md)
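Because every chapter must appear in SUMMARY.md, a first draft of it can be generated from the chapter files themselves. A minimal sketch, assuming each chapter starts with a `# Title` heading (the file names here are made up):

```shell
#!/bin/bash
# Scaffold a src/ directory and derive SUMMARY.md from chapter headings
cd "$(mktemp -d)"
mkdir -p src
printf '# Chapter One\n'  > src/chapter_1.md
printf '# Chapter Two\n'  > src/chapter_2.md
{
    echo "# Summary"
    echo
    for f in src/chapter_*.md; do
        title=$(head -n 1 "$f" | sed 's/^# //')   # heading text as link title
        echo "- [$title](./$(basename "$f"))"
    done
} > src/SUMMARY.md
cat src/SUMMARY.md
```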
Commands
Build Commands
# Build book
mdbook build
# Build and watch for changes
mdbook watch
# Serve with live reload
mdbook serve
# Serve on different port
mdbook serve -p 8080
# Serve on specific address
mdbook serve -n 0.0.0.0
# Open in browser
mdbook serve --open
# Build to different directory
mdbook build -d /tmp/mybook
Testing
# Test code examples
mdbook test
# Test with specific library
mdbook test --library-path ./target/debug
# Test specific chapter
mdbook test path/to/chapter.md
Cleaning
# Clean build directory
mdbook clean
# Remove specific build
rm -rf book/
Markdown Extensions
Code Blocks
```rust
fn main() {
println!("Hello, world!");
}
```
```rust,editable
// This code can be edited in browser
fn main() {
println!("Try editing me!");
}
```
```rust,ignore
// This code won't be tested
fn incomplete() {
```
```rust,no_run
// Compiles but doesn't run during tests
fn main() {
std::process::exit(1);
}
```
```rust,should_panic
// Expected to panic
fn main() {
panic!("Expected panic");
}
```
```python
def greet(name):
print(f"Hello, {name}!")
```
```bash
#!/bin/bash
echo "Hello from bash"
```
Include Files
<!-- Include entire file -->
{{#include path/to/file.rs}}
<!-- Include specific lines -->
{{#include path/to/file.rs:10:20}}
<!-- Include from line to end -->
{{#include path/to/file.rs:10:}}
<!-- Include with anchor -->
{{#include path/to/file.rs:my_anchor}}
Rust Playground
```rust,editable
{{#playground example.rs}}
```
```rust
{{#rustdoc_include path/to/lib.rs}}
```
Customization
Custom CSS
/* custom.css */
:root {
--sidebar-width: 300px;
--page-padding: 20px;
--content-max-width: 900px;
}
.content {
font-size: 18px;
line-height: 1.8;
}
.chapter {
padding: 2em;
}
code {
font-family: 'Fira Code', monospace;
}
pre {
border-radius: 8px;
}
Custom JavaScript
// custom.js
window.addEventListener('load', function() {
// Add custom functionality
console.log('Book loaded');
// Add copy button to code blocks
document.querySelectorAll('pre > code').forEach(function(code) {
const button = document.createElement('button');
button.textContent = 'Copy';
button.onclick = function() {
navigator.clipboard.writeText(code.textContent);
button.textContent = 'Copied!';
setTimeout(() => button.textContent = 'Copy', 2000);
};
code.parentElement.insertBefore(button, code);
});
});
Custom Theme
# book.toml
[output.html]
theme = "my-theme"
# Create theme directory
# mkdir -p my-theme
# Copy and modify default theme files
# Extract default theme
mdbook init --theme
# Files created in theme/:
# - index.hbs # Main template
# - head.hbs # HTML head
# - header.hbs # Page header
# - chrome.css # UI styles
# - general.css # Content styles
# - variables.css # CSS variables
Preprocessors
Built-in Preprocessors
# Enable links preprocessor
[preprocessor.links]
# Example usage in Markdown:
# [Rust](https://www.rust-lang.org/)
Custom Preprocessor
// my-preprocessor/src/main.rs
use mdbook::book::Book;
use mdbook::errors::Error;
use mdbook::preprocess::{Preprocessor, PreprocessorContext};

struct MyPreprocessor;

impl Preprocessor for MyPreprocessor {
    fn name(&self) -> &str {
        "my-preprocessor"
    }

    fn run(&self, _ctx: &PreprocessorContext, book: Book) -> Result<Book, Error> {
        // Process book content here
        Ok(book)
    }
}

fn main() {
    // handle_preprocessing follows the helper from mdBook's example preprocessor
    let preprocessor = MyPreprocessor;
    if let Err(e) = mdbook::preprocess::handle_preprocessing(&preprocessor) {
        eprintln!("{}", e);
        std::process::exit(1);
    }
}
# book.toml
[preprocessor.my-preprocessor]
command = "my-preprocessor"
Deployment
GitHub Pages
# Build book
mdbook build
# Initialize git (if needed)
git init
git add .
git commit -m "Initial commit"
# Create gh-pages branch
git checkout --orphan gh-pages
git reset --hard
cp -r book/* .
rm -rf book src
git add .
git commit -m "Deploy book"
git push origin gh-pages
# Or use GitHub Actions
GitHub Actions Workflow
# .github/workflows/deploy.yml
name: Deploy mdBook
on:
push:
branches: [ main ]
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Setup mdBook
uses: peaceiris/actions-mdbook@v1
with:
mdbook-version: 'latest'
- name: Build
run: mdbook build
- name: Deploy
uses: peaceiris/actions-gh-pages@v3
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
publish_dir: ./book
Netlify
# netlify.toml
[build]
command = "mdbook build"
publish = "book"
[build.environment]
RUST_VERSION = "1.70.0"
Docker
# Dockerfile
FROM rust:1.70 as builder
RUN cargo install mdbook
WORKDIR /book
COPY . .
RUN mdbook build
FROM nginx:alpine
COPY --from=builder /book/book /usr/share/nginx/html
Common Patterns
Multi-Language Book
# book.toml
[book]
multilingual = true
[output.html]
redirect = { "/" = "/en/" }
# Directory structure:
# src/
# ├── en/
# │ ├── SUMMARY.md
# │ └── chapter_1.md
# └── es/
# ├── SUMMARY.md
# └── chapter_1.md
Code Examples Project
<!-- Link to example project -->
See the [full example](https://github.com/user/repo/tree/main/examples/basic)
<!-- Include code from example -->
{{#include ../../examples/basic/src/main.rs}}
Versioned Documentation
#!/bin/bash
# build_versions.sh
VERSIONS=("v1.0" "v1.1" "v2.0")
for version in "${VERSIONS[@]}"; do
git checkout $version
mdbook build -d "book/$version"
done
# Create index.html for version selection
Best Practices
Content Organization
# Recommended structure:
src/
├── SUMMARY.md
├── introduction.md
├── guide/
│ ├── README.md # Chapter intro
│ ├── basics.md
│ └── advanced.md
├── reference/
│ ├── README.md
│ ├── api.md
│ └── cli.md
├── examples/
│ └── tutorial.md
└── appendix/
├── glossary.md
└── resources.md
Markdown Style
# Use consistent heading levels
## Chapter Title
### Section
#### Subsection
# Use relative links
[Link to other chapter](../other/chapter.md)
# Use descriptive alt text for images

# Include language in code blocks
```rust
fn main() {}
Use admonitions (with appropriate CSS)
Note: Important information
Warning: Be careful here
Performance Tips
# Minimize preprocessors
# Use relative links
# Optimize images
# Enable search caching
[output.html.search]
limit-results = 20
Troubleshooting
# Build fails
mdbook build -v # Verbose output
# Links not working
# Use relative links: ./file.md or ../other/file.md
# Search not working
[output.html.search]
enable = true
# Changes not reflecting
mdbook clean && mdbook build
# Port already in use
mdbook serve -p 3001
# Code not highlighting
# Ensure language is specified in code blocks
Quick Reference
| Command | Description |
|---|---|
| mdbook init | Create new book |
| mdbook build | Build book |
| mdbook serve | Serve with live reload |
| mdbook test | Test code examples |
| mdbook clean | Clean build directory |
| mdbook watch | Watch for changes |
mdBook is an excellent tool for creating beautiful, fast, and maintainable documentation, perfect for technical books, tutorials, API documentation, and user guides.
sed
sed is a stream editor for filtering and transforming text. It reads input line by line and applies editing commands, which makes it ideal for search-and-replace, deletion, and insertion in scripts and pipelines.
Commonly Used sed Commands
- Replace all occurrences of a string in a file:
  sed -i 's/old_string/new_string/g' filename
- Delete lines containing a specific pattern:
  sed -i '/pattern/d' filename
- Print only lines that match a pattern:
  sed -n '/pattern/p' filename
- Insert a line after a match:
  sed -i '/pattern/a\new_line' filename
- Insert a line before a match:
  sed -i '/pattern/i\new_line' filename
- Replace text on a specific line number:
  sed -i '3s/old_text/new_text/' filename
- Delete a specific line number:
  sed -i '5d' filename
- Replace text between two patterns:
  sed -i '/start_pattern/,/end_pattern/s/old_text/new_text/g' filename
- Print lines between two patterns:
  sed -n '/start_pattern/,/end_pattern/p' filename
- Append text to the end of each line:
  sed -i 's/$/append_text/' filename
These commands cover a variety of common use cases for the sed command, making it a versatile tool for text manipulation and processing.
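Several of these commands can be exercised on a scratch file (a standalone sketch; the file contents are made up):

```shell
#!/bin/bash
# Exercise replace, delete, and insert on a throwaway file
cd "$(mktemp -d)"
printf 'alpha\nbeta\nold value\n' > demo.txt
sed -i 's/old/new/g' demo.txt      # replace all occurrences
sed -i '/beta/d' demo.txt          # delete lines matching a pattern
sed -i '1i first line' demo.txt    # insert before line 1 (GNU sed syntax)
cat demo.txt                       # -> first line / alpha / new value
```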
awk
awk is a pattern-scanning and text-processing language. It splits each input line into fields, which makes it ideal for extracting columns, filtering rows, and computing summaries over tabular data.
Commonly Used awk Commands
- Print all lines in a file:
  awk '{print}' filename
- Print the first column of a file:
  awk '{print $1}' filename
- Print the first and second columns of a file:
  awk '{print $1, $2}' filename
- Print lines that match a pattern:
  awk '/pattern/ {print}' filename
- Print lines where the value in the first column is greater than 10:
  awk '$1 > 10' filename
- Calculate the sum of values in the first column:
  awk '{sum += $1} END {print sum}' filename
- Calculate the average of values in the first column:
  awk '{sum += $1; count++} END {print sum/count}' filename
- Print the number of lines in a file:
  awk 'END {print NR}' filename
- Print lines with more than 3 fields:
  awk 'NF > 3' filename
- Replace a string in a file:
  awk '{gsub(/old_string/, "new_string"); print}' filename
These commands cover a variety of common use cases for the awk command, making it a versatile tool for text processing and data extraction.
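The filtering and summing one-liners above can be checked against a small sample (a standalone sketch; the data is made up):

```shell
#!/bin/bash
# Run the filter, sum, and average one-liners on sample data
cd "$(mktemp -d)"
printf '5 apples\n20 oranges\n15 pears\n' > fruit.txt
awk '$1 > 10' fruit.txt                                      # rows with first column > 10
awk '{sum += $1} END {print sum}' fruit.txt                  # total: 40
awk '{sum += $1; count++} END {print sum/count}' fruit.txt   # average of column 1
```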
curl
curl (Client URL) is a command-line tool and library for transferring data with URLs. It supports a wide range of protocols including HTTP, HTTPS, FTP, FTPS, SCP, SFTP, TFTP, and more.
Overview
curl is one of the most versatile tools for testing APIs, downloading files, and debugging network requests. It's available on virtually all platforms and is commonly used in scripts and automation.
Key Features:
- Support for numerous protocols (HTTP, HTTPS, FTP, SMTP, etc.)
- Authentication support (Basic, Digest, OAuth, etc.)
- SSL/TLS support
- Cookie handling
- Resume transfers
- Proxy support
- Rate limiting
- Custom headers and methods
Basic Usage
Simple GET Request
# Basic GET request
curl https://api.example.com
# GET with output to file
curl https://example.com -o output.html
curl https://example.com --output output.html
# Save with original filename
curl -O https://example.com/file.pdf
# Follow redirects
curl -L https://shortened-url.com
Viewing Response Details
# Show response headers only
curl -I https://api.example.com
curl --head https://api.example.com
# Show response headers and body
curl -i https://api.example.com
curl --include https://api.example.com
# Verbose output (shows request/response details)
curl -v https://api.example.com
curl --verbose https://api.example.com
# Show only HTTP status code
curl -o /dev/null -s -w "%{http_code}\n" https://api.example.com
HTTP Methods
GET Request
# GET with query parameters
curl "https://api.example.com/users?page=1&limit=10"
# GET with URL-encoded parameters
curl -G https://api.example.com/search \
-d "query=curl tutorial" \
-d "limit=5"
POST Request
# POST with form data
curl -X POST https://api.example.com/users \
-d "name=John" \
-d "email=john@example.com"
# POST with JSON data
curl -X POST https://api.example.com/users \
-H "Content-Type: application/json" \
-d '{"name":"John","email":"john@example.com"}'
# POST with JSON from file
curl -X POST https://api.example.com/users \
-H "Content-Type: application/json" \
-d @data.json
# POST with form file upload
curl -X POST https://api.example.com/upload \
-F "file=@document.pdf" \
-F "description=My document"
PUT Request
# PUT to update resource
curl -X PUT https://api.example.com/users/123 \
-H "Content-Type: application/json" \
-d '{"name":"John Updated","email":"john.new@example.com"}'
PATCH Request
# PATCH to partially update resource
curl -X PATCH https://api.example.com/users/123 \
-H "Content-Type: application/json" \
-d '{"email":"newemail@example.com"}'
DELETE Request
# DELETE a resource
curl -X DELETE https://api.example.com/users/123
# DELETE with authentication
curl -X DELETE https://api.example.com/users/123 \
-H "Authorization: Bearer token123"
Headers
Custom Headers
# Single custom header
curl -H "X-Custom-Header: value" https://api.example.com
# Multiple headers
curl -H "Content-Type: application/json" \
-H "Authorization: Bearer token123" \
-H "X-Request-ID: abc123" \
https://api.example.com
# User-Agent header
curl -A "MyApp/1.0" https://api.example.com
curl --user-agent "MyApp/1.0" https://api.example.com
# Referer header
curl -e "https://referrer.com" https://api.example.com
curl --referer "https://referrer.com" https://api.example.com
Accept Headers
# Request JSON response
curl -H "Accept: application/json" https://api.example.com
# Request XML response
curl -H "Accept: application/xml" https://api.example.com
# Request specific API version
curl -H "Accept: application/vnd.api+json; version=2" https://api.example.com
Authentication
Basic Authentication
# Basic auth (username:password)
curl -u username:password https://api.example.com
# Basic auth with prompt for password
curl -u username https://api.example.com
# Basic auth in URL (not recommended for production)
curl https://username:password@api.example.com
Bearer Token
# Bearer token authentication
curl -H "Authorization: Bearer your_token_here" https://api.example.com
# Using environment variable
export TOKEN="your_token_here"
curl -H "Authorization: Bearer $TOKEN" https://api.example.com
API Key
# API key in header
curl -H "X-API-Key: your_api_key" https://api.example.com
# API key in query parameter
curl "https://api.example.com/data?api_key=your_api_key"
OAuth 2.0
# OAuth 2.0 with access token
curl -H "Authorization: Bearer access_token" https://api.example.com
# Get OAuth token
curl -X POST https://auth.example.com/token \
-d "grant_type=client_credentials" \
-d "client_id=your_client_id" \
-d "client_secret=your_client_secret"
Cookies
Managing Cookies
# Save cookies to file
curl -c cookies.txt https://example.com/login \
-d "username=user&password=pass"
# Load cookies from file
curl -b cookies.txt https://example.com/profile
# Send cookies directly
curl -b "session=abc123; user=john" https://example.com
# Save and load cookies in same request
curl -b cookies.txt -c cookies.txt https://example.com
File Operations
Downloading Files
# Download single file
curl -O https://example.com/file.zip
# Download with custom name
curl -o myfile.zip https://example.com/file.zip
# Download multiple files
curl -O https://example.com/file1.zip \
-O https://example.com/file2.zip
# Resume interrupted download
curl -C - -O https://example.com/largefile.zip
# Download with progress bar
curl -# -O https://example.com/file.zip
Uploading Files
# Upload file with PUT
curl -X PUT https://api.example.com/files/document.pdf \
--upload-file document.pdf
# Upload with POST multipart
curl -F "file=@document.pdf" https://api.example.com/upload
# Upload multiple files
curl -F "file1=@doc1.pdf" \
-F "file2=@doc2.pdf" \
https://api.example.com/upload
FTP Operations
# Download from FTP
curl ftp://ftp.example.com/file.txt -u username:password
# Upload to FTP
curl -T localfile.txt ftp://ftp.example.com/ -u username:password
# List FTP directory
curl ftp://ftp.example.com/ -u username:password
Advanced Options
Timeouts
# Connection timeout (seconds)
curl --connect-timeout 10 https://api.example.com
# Maximum time for entire operation
curl --max-time 30 https://api.example.com
curl -m 30 https://api.example.com
# Keepalive time
curl --keepalive-time 60 https://api.example.com
Retry Logic
# Retry on failure
curl --retry 3 https://api.example.com
# Retry with delay
curl --retry 3 --retry-delay 5 https://api.example.com
# Retry on specific errors
curl --retry 3 --retry-connrefused https://api.example.com
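Note that `--retry` only covers transient transport errors, and plain curl exits 0 on HTTP error status codes unless `-f` is used. For full control, a shell-level wrapper works too; `retry` below is a hypothetical helper, demonstrated with a stand-in command instead of a live URL:

```shell
#!/bin/bash
# retry MAX CMD... - rerun CMD until it succeeds, at most MAX attempts
retry() {
    local max=$1 n=1
    shift
    until "$@"; do
        [ "$n" -ge "$max" ] && return 1   # give up after MAX attempts
        n=$((n + 1))
        sleep 1                            # short pause between attempts
    done
}
cd "$(mktemp -d)"
flaky() {  # stand-in command: fails on the first call, succeeds after
    if [ ! -f ran_once ]; then
        touch ran_once
        return 1
    fi
}
retry 3 flaky && echo "succeeded after retrying"
```

In practice the last line would be something like `retry 3 curl -fsS https://api.example.com`.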
Rate Limiting
# Limit download speed (K = kilobytes, M = megabytes)
curl --limit-rate 100K https://example.com/largefile.zip
# Limit upload speed
curl --limit-rate 50K -T file.zip https://example.com/upload
Proxy
# Use HTTP proxy
curl -x http://proxy.example.com:8080 https://api.example.com
# Use SOCKS5 proxy
curl --socks5 proxy.example.com:1080 https://api.example.com
# Proxy with authentication
curl -x http://user:pass@proxy.example.com:8080 https://api.example.com
# Bypass proxy for specific hosts
curl --noproxy "localhost,127.0.0.1" -x proxy.example.com:8080 https://api.example.com
SSL/TLS Options
# Ignore SSL certificate validation (unsafe - use only for testing)
curl -k https://self-signed.example.com
curl --insecure https://self-signed.example.com
# Specify SSL version
curl --tlsv1.2 https://api.example.com
# Use client certificate
curl --cert client.pem --key key.pem https://api.example.com
# Use CA certificate
curl --cacert ca-bundle.crt https://api.example.com
Response Formatting
Format Output
# Pretty print JSON response (with jq)
curl https://api.example.com/users | jq '.'
# Extract specific field from JSON
curl https://api.example.com/users | jq '.data[].name'
# Silent mode (no progress bar)
curl -s https://api.example.com
# Show only errors
curl -S -s https://api.example.com
# Output format string
curl -w "\nTime: %{time_total}s\nStatus: %{http_code}\n" https://api.example.com
Custom Output Variables
# Show timing information
curl -w "
time_namelookup: %{time_namelookup}
time_connect: %{time_connect}
time_appconnect: %{time_appconnect}
time_pretransfer: %{time_pretransfer}
time_redirect: %{time_redirect}
time_starttransfer: %{time_starttransfer}
time_total: %{time_total}
http_code: %{http_code}
" -o /dev/null -s https://api.example.com
# Save format to file
curl -w "@curl-format.txt" -o /dev/null -s https://api.example.com
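The `curl-format.txt` referenced above is just a plain-text file of `-w` variables. One reasonable layout (the variable selection here is a suggestion, not a fixed format):

```shell
#!/bin/bash
# Create a reusable write-out format file for curl -w "@curl-format.txt"
cd "$(mktemp -d)"
cat > curl-format.txt <<'EOF'
    time_namelookup:  %{time_namelookup}s
       time_connect:  %{time_connect}s
 time_starttransfer:  %{time_starttransfer}s
         time_total:  %{time_total}s
          http_code:  %{http_code}
EOF
# Usage: curl -w "@curl-format.txt" -o /dev/null -s https://api.example.com
```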
Debugging
Verbose Output
# Show detailed request/response
curl -v https://api.example.com
# Even more verbose (includes SSL info)
curl -vv https://api.example.com
# Trace ASCII
curl --trace-ascii debug.txt https://api.example.com
# Trace binary
curl --trace debug.bin https://api.example.com
Testing APIs
# Test API endpoint
curl -I https://api.example.com/health
# Test with timeout
curl -m 5 https://api.example.com
# Check response time
time curl -o /dev/null -s https://api.example.com
# Test with different methods
for method in GET POST PUT DELETE; do
echo "Testing $method:"
curl -X $method -I https://api.example.com/test
done
Common Patterns
API Testing Script
#!/bin/bash
BASE_URL="https://api.example.com"
TOKEN="your_token_here"
# GET request
curl -H "Authorization: Bearer $TOKEN" "$BASE_URL/users"
# POST request
curl -X POST "$BASE_URL/users" \
-H "Authorization: Bearer $TOKEN" \
-H "Content-Type: application/json" \
-d '{"name":"John","email":"john@example.com"}'
# Check status
STATUS=$(curl -o /dev/null -s -w "%{http_code}" "$BASE_URL/health")
if [ "$STATUS" -eq 200 ]; then
echo "API is healthy"
else
echo "API returned status $STATUS"
fi
Download with Progress
# Download with progress bar
curl -# -L -o file.zip https://example.com/download
# Download with custom progress
curl --progress-bar -o file.zip https://example.com/download
REST API CRUD Operations
# Create
curl -X POST https://api.example.com/items \
-H "Content-Type: application/json" \
-d '{"name":"Item1","price":99.99}'
# Read
curl https://api.example.com/items/1
# Update
curl -X PUT https://api.example.com/items/1 \
-H "Content-Type: application/json" \
-d '{"name":"Item1 Updated","price":89.99}'
# Delete
curl -X DELETE https://api.example.com/items/1
Configuration File
Create ~/.curlrc for default options:
# Always follow redirects
-L
# Show error messages
--show-error
# Retry on failure
--retry 3
# Set user agent
user-agent = "MyApp/1.0"
# Always use HTTP/2 if available
--http2
Best Practices
- Use verbose mode for debugging:
  curl -v https://api.example.com
- Always handle errors in scripts:
  if ! curl -f https://api.example.com; then
      echo "Request failed"
      exit 1
  fi
- Use environment variables for sensitive data:
  export API_TOKEN="secret"
  curl -H "Authorization: Bearer $API_TOKEN" https://api.example.com
- Set appropriate timeouts:
  curl --connect-timeout 10 --max-time 60 https://api.example.com
- Save and reuse cookies for session management:
  curl -c cookies.txt -d "user=john&pass=secret" https://example.com/login
  curl -b cookies.txt https://example.com/profile
Common Use Cases
Health Check Monitoring
#!/bin/bash
# Check if service is up
while true; do
STATUS=$(curl -o /dev/null -s -w "%{http_code}" https://api.example.com/health)
if [ "$STATUS" -eq 200 ]; then
echo "$(date): Service is up"
else
echo "$(date): Service returned $STATUS"
fi
sleep 60
done
API Load Testing
# Simple load test
for i in {1..100}; do
curl -o /dev/null -s -w "%{time_total}\n" https://api.example.com &
done
wait
Web Scraping
# Download webpage and extract links
curl -s https://example.com | grep -oP 'href="\K[^"]*'
Testing Webhooks
# Send webhook payload
curl -X POST https://webhook.site/unique-url \
-H "Content-Type: application/json" \
-d '{"event":"user.created","data":{"id":123,"name":"John"}}'
Useful Aliases
# Add to ~/.bashrc or ~/.zshrc
alias curljson='curl -H "Content-Type: application/json"'
alias curlpost='curl -X POST -H "Content-Type: application/json"'
alias curltime='curl -w "\nTotal time: %{time_total}s\n"'
alias curlstatus='curl -o /dev/null -s -w "%{http_code}\n"'
Common Options Reference
| Option | Description |
|---|---|
| -X, --request | HTTP method (GET, POST, etc.) |
| -H, --header | Custom header |
| -d, --data | POST data |
| -F, --form | Multipart form data |
| -o, --output | Write to file |
| -O, --remote-name | Save with remote name |
| -L, --location | Follow redirects |
| -i, --include | Include headers in output |
| -I, --head | Fetch headers only |
| -v, --verbose | Verbose output |
| -s, --silent | Silent mode |
| -u, --user | Username:password |
| -b, --cookie | Cookie string or file |
| -c, --cookie-jar | Save cookies to file |
| -A, --user-agent | User-Agent string |
| -e, --referer | Referer URL |
| -k, --insecure | Ignore SSL errors |
| -x, --proxy | Use proxy |
| -m, --max-time | Maximum time in seconds |
| --retry | Number of retries |
Troubleshooting
Common Errors
# SSL certificate problem
curl --cacert /path/to/ca-bundle.crt https://example.com
# Connection timeout
curl --connect-timeout 30 https://example.com
# DNS resolution issues
curl --dns-servers 8.8.8.8 https://example.com
# Test specific IP
curl --resolve example.com:443:1.2.3.4 https://example.com
Debug SSL Issues
# Show SSL certificate details
curl -vv https://example.com 2>&1 | grep -A 10 "SSL certificate"
# Test SSL handshake
openssl s_client -connect example.com:443
# Use specific TLS version
curl --tlsv1.2 https://example.com
curl is an incredibly powerful tool for working with APIs, testing endpoints, and automating HTTP requests. Master these patterns and you'll be able to handle almost any HTTP-related task from the command line.
wget
wget is a free command-line utility for non-interactive downloading of files from the web. It supports HTTP, HTTPS, and FTP protocols, and can work through proxies, resume downloads, and handle various network conditions.
Overview
wget is designed for robustness over slow or unstable network connections. If a download fails, it will keep retrying until the whole file has been retrieved. It's ideal for downloading files in scripts and automated tasks.
Key Features:
- Non-interactive operation (works in background)
- Resume interrupted downloads
- Recursive downloads (entire websites)
- Multiple protocol support (HTTP, HTTPS, FTP)
- Proxy support
- Timestamping and mirroring
- Convert links for offline viewing
- Bandwidth limiting
Installation
# Ubuntu/Debian
sudo apt update
sudo apt install wget
# macOS
brew install wget
# CentOS/RHEL
sudo yum install wget
# Arch Linux
sudo pacman -S wget
# Verify installation
wget --version
Basic Usage
Simple Downloads
# Download a file
wget https://example.com/file.zip
# Download and save with different name
wget -O myfile.zip https://example.com/file.zip
wget --output-document=myfile.zip https://example.com/file.zip
# Download to specific directory
wget -P /path/to/directory https://example.com/file.zip
wget --directory-prefix=/path/to/directory https://example.com/file.zip
# Download in background
wget -b https://example.com/largefile.zip
wget --background https://example.com/largefile.zip
# Continue interrupted download
wget -c https://example.com/largefile.zip
wget --continue https://example.com/largefile.zip
Multiple Files
# Download multiple files
wget https://example.com/file1.zip https://example.com/file2.zip
# Download from file list
cat urls.txt
# https://example.com/file1.zip
# https://example.com/file2.zip
# https://example.com/file3.zip
wget -i urls.txt
wget --input-file=urls.txt
# Download a numbered sequence (the {1..10} range is expanded by the shell, not wget)
wget https://example.com/file{1..10}.zip
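A numbered sequence like the one above can also be materialized into a list file for `-i` (a sketch with placeholder URLs):

```shell
#!/bin/bash
# Generate urls.txt for `wget -i urls.txt` from a numeric range
cd "$(mktemp -d)"
seq -f "https://example.com/file%g.zip" 1 3 > urls.txt
cat urls.txt
# Then: wget -i urls.txt
```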
Download Options
# Limit download speed (K, M, G)
wget --limit-rate=200k https://example.com/file.zip
wget --limit-rate=1M https://example.com/file.zip
# Set number of retries
wget --tries=10 https://example.com/file.zip
wget -t 10 https://example.com/file.zip
# Infinite retries
wget --tries=0 https://example.com/file.zip
# Timeout settings
wget --timeout=30 https://example.com/file.zip
wget --dns-timeout=10 --connect-timeout=10 --read-timeout=30 https://example.com/file.zip
# Wait between downloads
wget --wait=5 -i urls.txt # Wait 5 seconds
wget --random-wait -i urls.txt # Random wait 0.5-1.5x wait time
Recursive Downloads
Mirror Websites
# Mirror entire website
wget --mirror --convert-links --page-requisites --no-parent https://example.com
# Shorter version
wget -mkEpnp https://example.com
# Flags explained:
# -m, --mirror: mirror (recursive + timestamping + infinite depth)
# -k, --convert-links: convert links for offline viewing
# -E, --adjust-extension: save HTML with .html extension
# -p, --page-requisites: get all images, CSS, etc.
# -np, --no-parent: don't ascend to parent directory
# Limit recursion depth
wget -r -l 2 https://example.com # 2 levels deep
wget --recursive --level=2 https://example.com
# Download specific file types only
wget -r -A pdf,jpg,png https://example.com
wget --recursive --accept=pdf,jpg,png https://example.com
# Exclude specific file types
wget -r -R gif,svg https://example.com
wget --recursive --reject=gif,svg https://example.com
Download Directories
# Download entire directory
wget -r -np -nH --cut-dirs=2 https://example.com/files/documents/
# Flags explained:
# -r: recursive
# -np: no parent (stay in directory)
# -nH: no host directory
# --cut-dirs=2: skip 2 directory levels
# Example:
# URL: https://example.com/files/documents/pdf/file.pdf
# Without flags: example.com/files/documents/pdf/file.pdf
# With flags: pdf/file.pdf
Authentication
HTTP Authentication
# Basic authentication
wget --user=username --password=password https://example.com/file.zip
# Prompt for password
wget --user=username --ask-password https://example.com/file.zip
# HTTP authentication via .wgetrc
cat << EOF > ~/.wgetrc
http_user = username
http_password = password
EOF
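Since `.wgetrc` stores the password in plaintext, it should be readable only by its owner. A minimal sketch, demonstrated on a temp file rather than the real `~/.wgetrc`:

```shell
# Restrict a credentials file to owner read/write only
rc=$(mktemp)
printf 'http_user = username\nhttp_password = password\n' > "$rc"
chmod 600 "$rc"
stat -c '%a' "$rc"   # prints: 600
rm -f "$rc"
```

In practice: `chmod 600 ~/.wgetrc` right after creating it.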
FTP Authentication
# FTP download with credentials
wget ftp://username:password@ftp.example.com/file.zip
# Anonymous FTP
wget ftp://ftp.example.com/file.zip
Cookies
# Send cookies
wget --header="Cookie: session=abc123" https://example.com/file.zip
# Load cookies from file
wget --load-cookies=cookies.txt https://example.com/file.zip
# Save cookies to file
wget --save-cookies=cookies.txt --keep-session-cookies https://example.com/login
# Use cookies for authenticated download
wget --save-cookies=cookies.txt --keep-session-cookies \
--post-data='user=john&pass=secret' \
https://example.com/login
wget --load-cookies=cookies.txt https://example.com/protected/file.zip
Headers and User Agent
Custom Headers
# Set user agent
wget --user-agent="Mozilla/5.0" https://example.com/file.zip
wget -U "Mozilla/5.0" https://example.com/file.zip
# Custom headers
wget --header="Accept: application/json" https://api.example.com/data
wget --header="Authorization: Bearer token123" https://api.example.com/file.zip
# Multiple headers
wget --header="Accept: application/json" \
--header="X-API-Key: abc123" \
https://api.example.com/data
# Referer header
wget --referer=https://example.com https://example.com/file.zip
POST Requests
# POST data
wget --post-data='name=John&email=john@example.com' https://example.com/api
# POST from file
wget --post-file=data.json https://example.com/api
# POST with headers
wget --post-data='{"name":"John"}' \
--header="Content-Type: application/json" \
https://example.com/api
SSL/TLS Options
# Ignore SSL certificate check (unsafe)
wget --no-check-certificate https://self-signed.example.com/file.zip
# Specify CA certificate
wget --ca-certificate=/path/to/ca-cert.pem https://example.com/file.zip
# Use client certificate
wget --certificate=/path/to/client-cert.pem \
--certificate-type=PEM \
https://example.com/file.zip
# Use private key
wget --private-key=/path/to/key.pem https://example.com/file.zip
# Specify SSL protocol
wget --secure-protocol=TLSv1_2 https://example.com/file.zip
Proxy Support
# Use HTTP proxy
wget -e use_proxy=yes -e http_proxy=http://proxy.example.com:8080 https://example.com/file.zip
# Use proxy with authentication
wget -e use_proxy=yes \
-e http_proxy=http://user:pass@proxy.example.com:8080 \
https://example.com/file.zip
# HTTPS proxy
wget -e https_proxy=http://proxy.example.com:8080 https://example.com/file.zip
# FTP proxy
wget -e ftp_proxy=http://proxy.example.com:8080 ftp://ftp.example.com/file.zip
# No proxy for specific domains
wget -e no_proxy=localhost,127.0.0.1 https://example.com/file.zip
# Configure in .wgetrc
cat << EOF > ~/.wgetrc
use_proxy = on
http_proxy = http://proxy.example.com:8080
https_proxy = http://proxy.example.com:8080
ftp_proxy = http://proxy.example.com:8080
no_proxy = localhost,127.0.0.1
EOF
Output Control
Verbosity
# Quiet mode (no output)
wget -q https://example.com/file.zip
wget --quiet https://example.com/file.zip
# Verbose output
wget -v https://example.com/file.zip
wget --verbose https://example.com/file.zip
# Debug output
wget -d https://example.com/file.zip
wget --debug https://example.com/file.zip
# Show progress bar only
wget --progress=bar https://example.com/file.zip
wget --progress=dot https://example.com/file.zip
# No verbose but show errors
wget -nv https://example.com/file.zip
wget --no-verbose https://example.com/file.zip
Logging
# Log to file
wget -o download.log https://example.com/file.zip
wget --output-file=download.log https://example.com/file.zip
# Append to log
wget -a download.log https://example.com/file.zip
wget --append-output=download.log https://example.com/file.zip
# Background download with logging
wget -b -o wget.log https://example.com/largefile.zip
Advanced Features
Timestamping
# Only download if newer than local file
wget -N https://example.com/file.zip
wget --timestamping https://example.com/file.zip
# Check if file has been modified
wget --spider --server-response https://example.com/file.zip
Spider Mode
# Check if file exists without downloading
wget --spider https://example.com/file.zip
# Check if URL is valid (--spider exits 0 on success, so no output parsing is needed)
if wget -q --spider https://example.com/file.zip; then
    echo "URL is valid"
else
    echo "URL is invalid"
fi
# Get response headers only
wget --spider --server-response https://example.com/file.zip
Quota and Limits
# Limit total download size
wget --quota=100M -i urls.txt
# Note: GNU Wget has no per-file size filter; --quota caps only the
# total downloaded in a session. For a per-file cap, use curl's
# --max-filesize or check file sizes after downloading.
Filtering
# Include only specific directories
wget -r -I /docs,/guides https://example.com
# Exclude specific directories
wget -r -X /private,/admin https://example.com
# Include only specific domains
wget -r -D example.com,cdn.example.com https://example.com
# Follow only relative links
wget -r --relative https://example.com
Configuration File
.wgetrc
# Create ~/.wgetrc
cat << 'EOF' > ~/.wgetrc
# Retry settings
tries = 10
retry_connrefused = on
# Timeout settings
timeout = 30
dns_timeout = 10
connect_timeout = 10
read_timeout = 30
# Wait between downloads
wait = 2
random_wait = on
# Download settings
continue = on
timestamping = on
# User agent
user_agent = Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36
# Proxy settings
# use_proxy = on
# http_proxy = http://proxy.example.com:8080
# https_proxy = http://proxy.example.com:8080
# Directories
dir_prefix = ~/Downloads/
# Output
verbose = off
quiet = off
EOF
Common Use Cases
Download Large Files
# Download with resume support
wget -c -t 0 --timeout=120 https://example.com/largefile.iso
# Download in background with logging
wget -b -c -o download.log https://example.com/largefile.iso
# Monitor background download
tail -f download.log
Backup Website
#!/bin/bash
# backup-website.sh
SITE="https://example.com"
BACKUP_DIR="/backup/website"
DATE=$(date +%Y%m%d)
mkdir -p "$BACKUP_DIR/$DATE"
cd "$BACKUP_DIR/$DATE"
wget --mirror \
--convert-links \
--adjust-extension \
--page-requisites \
--no-parent \
--no-clobber \
--wait=1 \
--random-wait \
"$SITE"
echo "Backup completed: $BACKUP_DIR/$DATE"
Download All PDFs from Site
# Download all PDFs
wget -r -A pdf https://example.com
# Download PDFs from specific directory
wget -r -np -nd -A pdf https://example.com/documents/
# Download PDFs with original structure
wget -r -np -A pdf https://example.com/documents/
API File Downloads
# Download with authentication token
wget --header="Authorization: Bearer $API_TOKEN" \
https://api.example.com/files/report.pdf
# Download with API key
wget --header="X-API-Key: $API_KEY" \
https://api.example.com/download/file.zip
Batch Downloads
# Create URL list
for i in {1..100}; do
echo "https://example.com/images/img${i}.jpg"
done > urls.txt
# Download with rate limiting
wget -i urls.txt --wait=1 --random-wait --limit-rate=500k
# Download with progress tracking
wget -i urls.txt -o download.log &
tail -f download.log | grep -E "saved|failed"
Scripting Examples
Download with Retry Logic
#!/bin/bash
URL="https://example.com/file.zip"
OUTPUT="file.zip"
MAX_ATTEMPTS=5
for i in $(seq 1 $MAX_ATTEMPTS); do
echo "Attempt $i of $MAX_ATTEMPTS"
if wget -c -O "$OUTPUT" "$URL"; then
echo "Download successful"
exit 0
else
echo "Download failed, retrying..."
sleep 5
fi
done
echo "Download failed after $MAX_ATTEMPTS attempts"
exit 1
Parallel Downloads
#!/bin/bash
# Download multiple files in parallel
URLS=(
"https://example.com/file1.zip"
"https://example.com/file2.zip"
"https://example.com/file3.zip"
)
for url in "${URLS[@]}"; do
wget -c "$url" &
done
# Wait for all downloads to complete
wait
echo "All downloads completed"
Monitor Website Changes
#!/bin/bash
# Check if website has been updated
URL="https://example.com/news.html"
OUTPUT="/tmp/news.html"
if [ -f "$OUTPUT" ]; then
wget -N -P /tmp -o /tmp/wget.log "$URL"
if grep -q "not retrieving" /tmp/wget.log; then
echo "No changes detected"
else
echo "Website has been updated"
# Send notification or perform action
fi
else
wget -O "$OUTPUT" "$URL"
echo "Initial download completed"
fi
Best Practices
- Always use resume support for large files: `wget -c`
- Be respectful with recursive downloads: use `--wait` and `--random-wait`
- Set appropriate timeout values for unreliable connections
- Use timestamping to avoid re-downloading unchanged files: `wget -N`
- Log downloads for troubleshooting: `wget -o logfile`
- Limit bandwidth if needed: `--limit-rate`
- Use `.wgetrc` for common settings
- Respect robots.txt; `wget --execute robots=off` overrides it (use responsibly)
Troubleshooting
Common Issues
# SSL certificate verification failed
wget --no-check-certificate https://example.com/file.zip
# Better: Install proper CA certificates
# Connection timeout
wget --timeout=60 --tries=5 https://example.com/file.zip
# 403 Forbidden error
wget --user-agent="Mozilla/5.0" https://example.com/file.zip
# Cannot write to file (permission denied)
sudo wget -P /protected/directory https://example.com/file.zip
# Resume failed download
wget -c https://example.com/file.zip
# Check download status in background
tail -f wget-log
# Verify download integrity
wget https://example.com/file.zip
wget https://example.com/file.zip.sha256
sha256sum -c file.zip.sha256
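The verify step above can be rehearsed locally before wiring it into a download script; a sketch using a generated file (filenames are placeholders):

```shell
# Create a file plus a checksum manifest, then verify the manifest
tmp=$(mktemp -d)
cd "$tmp"
echo "hello" > file.zip
sha256sum file.zip > file.zip.sha256
sha256sum -c file.zip.sha256   # prints: file.zip: OK
cd / && rm -rf "$tmp"
```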
Debug Issues
# Enable debug output
wget -d https://example.com/file.zip 2>&1 | tee debug.log
# Check DNS resolution
wget --dns-timeout=10 https://example.com/file.zip
# Test connection only
wget --spider --server-response https://example.com/file.zip
# Show headers
wget -S https://example.com/file.zip
Quick Reference
| Option | Description |
|---|---|
| `-O file` | Save as file |
| `-P dir` | Save to directory |
| `-c` | Continue/resume download |
| `-b` | Background download |
| `-i file` | Download URLs from file |
| `-r` | Recursive download |
| `-l N` | Recursion depth |
| `-A list` | Accept file types |
| `-R list` | Reject file types |
| `-np` | No parent directory |
| `-m` | Mirror website |
| `-k` | Convert links |
| `-p` | Page requisites |
| `-q` | Quiet mode |
| `-v` | Verbose mode |
| `-N` | Timestamping |
| `--limit-rate=N` | Limit speed |
| `--tries=N` | Number of retries |
| `--timeout=N` | Timeout seconds |
wget is a versatile tool for reliable file downloads, website mirroring, and automated download tasks, essential for system administrators and developers.
grep
grep is a command-line utility for searching text with patterns. It can search single files, multiple files, and entire directory trees, and supports regular expressions.
Commonly Used grep Commands
- Search for a specific string in a file: `grep "search_string" filename`
- Search for a string in multiple files: `grep "search_string" file1 file2 file3`
- Search recursively in directories: `grep -r "search_string" /path/to/directory`
- Search for a string ignoring case: `grep -i "search_string" filename`
- Search for a whole word: `grep -w "word" filename`
- Search for a string and display line numbers: `grep -n "search_string" filename`
- Search for a string and display the count of matching lines: `grep -c "search_string" filename`
- Search for a string and display only the matching part: `grep -o "search_string" filename`
- Search for lines that do not match the string: `grep -v "search_string" filename`
- Search for multiple patterns: `grep -e "pattern1" -e "pattern2" filename`
- Search for a string in compressed files: `zgrep "search_string" compressed_file.gz`
- Search for a string and display context lines: `grep -C 3 "search_string" filename`
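The commands above can be exercised on a throwaway file to see how the counting flags interact:

```shell
# Sample file: two lines contain "alpha", one contains "Beta"
f=$(mktemp)
printf 'alpha\nBeta\nalpha beta\n' > "$f"
grep -c "alpha" "$f"    # prints: 2  (lines matching)
grep -ic "beta" "$f"    # prints: 2  (case-insensitive count)
grep -vc "alpha" "$f"   # prints: 1  (lines NOT matching)
rm -f "$f"
```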
These commands cover a variety of common use cases for the grep command, making it a versatile tool for text searching and manipulation.
find
find is a command-line utility for locating files and directories. It recursively walks directory trees and can filter by name, type, size, timestamps, ownership, and permissions, and run commands on the results.
Commonly Used find Commands
- Find files by name: `find /path/to/directory -name "filename"` or `find . -name "*.py"`
- Find files by extension: `find /path/to/directory -name "*.ext"`
- Find files by type (e.g., directories): `find /path/to/directory -type d`
- Find files by size (e.g., files larger than 100MB): `find /path/to/directory -size +100M`
- Find files modified in the last 7 days: `find /path/to/directory -mtime -7`
- Find files accessed in the last 7 days: `find /path/to/directory -atime -7`
- Find files and execute a command on them (e.g., delete): `find /path/to/directory -name "*.tmp" -exec rm {} \;`
- Find files by permissions (e.g., files with 777 permissions): `find /path/to/directory -perm 777`
- Find empty files and directories: `find /path/to/directory -empty`
- Find files by user: `find /path/to/directory -user username`
- Find files by group: `find /path/to/directory -group groupname`
- Find files excluding a specific path: `find /path/to/directory -path /exclude/path -prune -o -name "*.ext" -print`
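A quick way to try the filters above is in a scratch directory; a minimal sketch:

```shell
# Build a small tree and count matches by name and by type
d=$(mktemp -d)
mkdir -p "$d/sub"
touch "$d/a.py" "$d/b.txt" "$d/sub/c.py"
find "$d" -name "*.py" | wc -l   # prints: 2
find "$d" -type d | wc -l        # prints: 2  (the root dir and sub/)
rm -rf "$d"
```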
These commands cover a variety of common use cases for the find command, making it a versatile tool for file searching and manipulation.
FFmpeg
FFmpeg is a complete, cross-platform solution to record, convert, and stream audio and video. It's one of the most powerful multimedia frameworks available, supporting virtually every codec and format.
Overview
FFmpeg is a command-line tool that can handle virtually any multimedia processing task. It consists of several components including ffmpeg (transcoder), ffprobe (media analyzer), and ffplay (media player).
Key Features:
- Convert between virtually all audio/video formats
- Change codecs, bitrates, and quality settings
- Extract audio from video or vice versa
- Resize, crop, rotate, and flip videos
- Apply filters and effects
- Generate thumbnails and screenshots
- Concatenate multiple files
- Stream to various protocols (RTMP, HLS, DASH)
- Hardware acceleration support
- Subtitle handling (extract, embed, burn-in)
Components:
- ffmpeg: Main command-line tool for conversion and processing
- ffprobe: Analyze media files (metadata, streams, format)
- ffplay: Simple media player for testing
- libavcodec: Codec library
- libavformat: Container format library
- libavfilter: Audio/video filtering library
Installation
Linux
# Ubuntu/Debian
sudo apt update
sudo apt install ffmpeg
# Fedora/RHEL
sudo dnf install ffmpeg
# Arch Linux
sudo pacman -S ffmpeg
# Build from source (latest features)
git clone https://git.ffmpeg.org/ffmpeg.git
cd ffmpeg
./configure --enable-gpl --enable-libx264 --enable-libx265
make
sudo make install
macOS
# Using Homebrew
brew install ffmpeg
# With additional codecs
# (Homebrew core no longer supports per-formula install options;
# use a third-party tap such as homebrew-ffmpeg instead)
brew tap homebrew-ffmpeg/ffmpeg
brew install homebrew-ffmpeg/ffmpeg/ffmpeg
# Check version
ffmpeg -version
Windows
# Using Chocolatey
choco install ffmpeg
# Or download from https://ffmpeg.org/download.html
# Extract and add to PATH
Basic Concepts
Containers vs Codecs
- Container (format): Wrapper that holds audio/video/subtitle streams (e.g., MP4, MKV, AVI)
- Codec: Algorithm for encoding/decoding media (e.g., H.264, AAC, VP9)
Common combinations:
- MP4 container: H.264 video + AAC audio
- MKV container: H.265 video + Opus audio
- WebM container: VP9 video + Vorbis audio
Stream Selection
FFmpeg identifies streams as:
- `0:v:0` - First video stream
- `0:a:0` - First audio stream
- `0:s:0` - First subtitle stream
Common Codec Identifiers
Video:
- `libx264` - H.264/AVC (widely compatible)
- `libx265` - H.265/HEVC (better compression)
- `libvpx-vp9` - VP9 (open, good for web)
- `libaom-av1` - AV1 (newest, best compression)
Audio:
- `aac` - AAC (standard)
- `libmp3lame` - MP3
- `libopus` - Opus (best quality/size)
- `libvorbis` - Vorbis (open)
Basic Usage
Get Media Information
# Detailed file information
ffprobe input.mp4
# Show only format information
ffprobe -show_format input.mp4
# Show stream information
ffprobe -show_streams input.mp4
# JSON output
ffprobe -print_format json -show_format -show_streams input.mp4
# Get video duration
ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 input.mp4
# Get video resolution
ffprobe -v error -select_streams v:0 -show_entries stream=width,height -of csv=s=x:p=0 input.mp4
# Get video framerate
ffprobe -v error -select_streams v:0 -show_entries stream=r_frame_rate -of default=noprint_wrappers=1:nokey=1 input.mp4
# Get bitrate
ffprobe -v error -show_entries format=bit_rate -of default=noprint_wrappers=1:nokey=1 input.mp4
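Note that `r_frame_rate` comes back as a fraction (e.g. `30000/1001` for NTSC video); a small awk sketch to turn it into a decimal, assuming the fraction format shown above:

```shell
# Convert ffprobe's fractional frame rate into a decimal value
echo "30000/1001" | awk -F/ '{ printf "%.2f\n", $1 / $2 }'   # prints: 29.97
```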
Simple Conversion
# Basic format conversion (auto-detect codecs)
ffmpeg -i input.avi output.mp4
# Convert with progress
ffmpeg -i input.avi -progress - output.mp4
# Overwrite output without prompt
ffmpeg -y -i input.avi output.mp4
# Never overwrite
ffmpeg -n -i input.avi output.mp4
Video Conversion
Format Conversion
# AVI to MP4
ffmpeg -i input.avi output.mp4
# MKV to MP4
ffmpeg -i input.mkv -c copy output.mp4 # Copy streams (fast)
# MOV to MP4
ffmpeg -i input.mov -c:v libx264 -c:a aac output.mp4
# WebM to MP4
ffmpeg -i input.webm -c:v libx264 -c:a aac output.mp4
# FLV to MP4
ffmpeg -i input.flv -c:v libx264 -c:a aac output.mp4
# MP4 to WebM
ffmpeg -i input.mp4 -c:v libvpx-vp9 -c:a libopus output.webm
# Any format to GIF
ffmpeg -i input.mp4 -vf "fps=10,scale=320:-1:flags=lanczos" output.gif
Stream Copying (Fast)
# Copy all streams without re-encoding
ffmpeg -i input.mp4 -c copy output.mkv
# Copy video, re-encode audio
ffmpeg -i input.mp4 -c:v copy -c:a aac output.mp4
# Copy audio, re-encode video
ffmpeg -i input.mp4 -c:v libx264 -c:a copy output.mp4
Video Encoding
H.264 Encoding
# Basic H.264 encoding
ffmpeg -i input.mp4 -c:v libx264 -c:a aac output.mp4
# High quality H.264
ffmpeg -i input.mp4 -c:v libx264 -preset slow -crf 18 -c:a aac -b:a 192k output.mp4
# Web-optimized H.264
ffmpeg -i input.mp4 -c:v libx264 -preset fast -crf 22 -c:a aac -b:a 128k -movflags +faststart output.mp4
# Specific bitrate
ffmpeg -i input.mp4 -c:v libx264 -b:v 2M -c:a aac -b:a 128k output.mp4
# Two-pass encoding (better quality)
ffmpeg -i input.mp4 -c:v libx264 -b:v 2M -pass 1 -f mp4 /dev/null
ffmpeg -i input.mp4 -c:v libx264 -b:v 2M -pass 2 output.mp4
# Presets (speed vs compression)
# ultrafast, superfast, veryfast, faster, fast, medium, slow, slower, veryslow
ffmpeg -i input.mp4 -c:v libx264 -preset slow -crf 20 output.mp4
# Profiles and levels
ffmpeg -i input.mp4 -c:v libx264 -profile:v baseline -level 3.0 output.mp4
ffmpeg -i input.mp4 -c:v libx264 -profile:v main -level 4.0 output.mp4
ffmpeg -i input.mp4 -c:v libx264 -profile:v high -level 4.2 output.mp4
H.265/HEVC Encoding
# Basic H.265 encoding
ffmpeg -i input.mp4 -c:v libx265 -c:a aac output.mp4
# High quality H.265
ffmpeg -i input.mp4 -c:v libx265 -preset slow -crf 22 -c:a aac output.mp4
# 4K H.265
ffmpeg -i input.mp4 -c:v libx265 -preset medium -crf 24 -c:a aac -tag:v hvc1 output.mp4
# H.265 with specific bitrate
ffmpeg -i input.mp4 -c:v libx265 -b:v 1.5M -c:a aac output.mp4
VP9 Encoding (WebM)
# Basic VP9
ffmpeg -i input.mp4 -c:v libvpx-vp9 -c:a libopus output.webm
# High quality VP9
ffmpeg -i input.mp4 -c:v libvpx-vp9 -crf 30 -b:v 0 -c:a libopus output.webm
# VP9 two-pass
ffmpeg -i input.mp4 -c:v libvpx-vp9 -b:v 1M -pass 1 -f webm /dev/null
ffmpeg -i input.mp4 -c:v libvpx-vp9 -b:v 1M -pass 2 -c:a libopus output.webm
# VP9 with quality settings
ffmpeg -i input.mp4 -c:v libvpx-vp9 -crf 30 -b:v 0 -row-mt 1 -c:a libopus -b:a 128k output.webm
AV1 Encoding
# Basic AV1 (slow but best compression)
ffmpeg -i input.mp4 -c:v libaom-av1 -crf 30 -c:a libopus output.webm
# AV1 with speed settings
ffmpeg -i input.mp4 -c:v libaom-av1 -cpu-used 4 -crf 30 output.webm
# SVT-AV1 (faster)
ffmpeg -i input.mp4 -c:v libsvtav1 -crf 35 -c:a libopus output.webm
Quality Control
# CRF (Constant Rate Factor) - recommended
# Lower = better quality, larger file
# H.264: 18-28 (23 default)
# H.265: 22-32 (28 default)
# VP9: 15-35 (30 default)
ffmpeg -i input.mp4 -c:v libx264 -crf 23 output.mp4
# CBR (Constant Bitrate)
ffmpeg -i input.mp4 -c:v libx264 -b:v 2M -minrate 2M -maxrate 2M -bufsize 1M output.mp4
# VBR (Variable Bitrate)
ffmpeg -i input.mp4 -c:v libx264 -b:v 2M -maxrate 3M -bufsize 2M output.mp4
# Target file size
# Calculate bitrate: (target_size_MB * 8192) / duration_seconds
ffmpeg -i input.mp4 -c:v libx264 -b:v 1500k -pass 1 -f mp4 /dev/null
ffmpeg -i input.mp4 -c:v libx264 -b:v 1500k -pass 2 output.mp4
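The target-size formula above is easy to script; a sketch assuming a hypothetical 700 MB target, a 90-minute duration (probe it with ffprobe as shown earlier), and a 128 kbit/s audio track that must fit inside the budget:

```shell
# (target_size_MB * 8192) / duration_seconds => total bitrate in kbit/s,
# then subtract the audio bitrate to get the video budget
target_mb=700
duration=5400        # 90 minutes, in seconds
audio_kbps=128
total=$(( target_mb * 8192 / duration ))
video=$(( total - audio_kbps ))
echo "${video}k"     # prints: 933k
```

Feed the result to `-b:v` in the two-pass commands above.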
Audio Operations
Audio Extraction
# Extract audio to MP3
ffmpeg -i input.mp4 -vn -c:a libmp3lame -b:a 192k output.mp3
# Extract audio to AAC
ffmpeg -i input.mp4 -vn -c:a aac -b:a 192k output.aac
# Extract audio to FLAC (lossless)
ffmpeg -i input.mp4 -vn -c:a flac output.flac
# Extract audio without re-encoding
ffmpeg -i input.mp4 -vn -c:a copy output.aac
Audio Conversion
# Convert audio format
ffmpeg -i input.mp3 output.wav
ffmpeg -i input.wav -c:a libmp3lame -b:a 320k output.mp3
ffmpeg -i input.mp3 -c:a aac -b:a 192k output.aac
ffmpeg -i input.wav -c:a libopus -b:a 128k output.opus
# Change sample rate
ffmpeg -i input.mp3 -ar 44100 output.mp3
ffmpeg -i input.wav -ar 48000 output.wav
# Change channels (mono/stereo)
ffmpeg -i input.mp3 -ac 1 output.mp3 # Mono
ffmpeg -i input.mp3 -ac 2 output.mp3 # Stereo
# Normalize audio
ffmpeg -i input.mp3 -af "loudnorm" output.mp3
# Change volume
ffmpeg -i input.mp3 -af "volume=2.0" output.mp3 # Double volume
ffmpeg -i input.mp3 -af "volume=0.5" output.mp3 # Half volume
ffmpeg -i input.mp3 -af "volume=10dB" output.mp3 # Increase by 10dB
Audio Bitrate
# Constant bitrate
ffmpeg -i input.mp4 -c:a aac -b:a 128k output.mp4
# Common bitrates
ffmpeg -i input.mp4 -c:a aac -b:a 96k output.mp4 # Low quality
ffmpeg -i input.mp4 -c:a aac -b:a 128k output.mp4 # Standard
ffmpeg -i input.mp4 -c:a aac -b:a 192k output.mp4 # Good quality
ffmpeg -i input.mp4 -c:a aac -b:a 256k output.mp4 # High quality
ffmpeg -i input.mp4 -c:a aac -b:a 320k output.mp4 # Maximum quality
Merge Audio and Video
# Replace audio in video
ffmpeg -i video.mp4 -i audio.mp3 -c:v copy -c:a aac -map 0:v:0 -map 1:a:0 output.mp4
# Add audio track (multiple audio streams)
ffmpeg -i video.mp4 -i audio.mp3 -c copy -map 0 -map 1:a output.mp4
# Mix two audio tracks
ffmpeg -i input1.mp3 -i input2.mp3 -filter_complex "[0:a][1:a]amix=inputs=2:duration=longest" output.mp3
Video Filters
Resize and Scale
# Resize to specific dimensions
ffmpeg -i input.mp4 -vf "scale=1280:720" output.mp4
# Resize maintaining aspect ratio (-2 auto-computes an even value, which libx264 requires)
ffmpeg -i input.mp4 -vf "scale=1280:-2" output.mp4  # Width 1280, auto height
ffmpeg -i input.mp4 -vf "scale=-2:720" output.mp4   # Height 720, auto width
# Scale to percentage
ffmpeg -i input.mp4 -vf "scale=iw*0.5:ih*0.5" output.mp4 # 50% size
# Common resolutions
ffmpeg -i input.mp4 -vf "scale=1920:1080" output.mp4 # 1080p
ffmpeg -i input.mp4 -vf "scale=1280:720" output.mp4 # 720p
ffmpeg -i input.mp4 -vf "scale=854:480" output.mp4 # 480p
ffmpeg -i input.mp4 -vf "scale=640:360" output.mp4 # 360p
# High quality scaling
ffmpeg -i input.mp4 -vf "scale=1920:1080:flags=lanczos" output.mp4
Crop
# Crop to specific size
# crop=width:height:x:y
ffmpeg -i input.mp4 -vf "crop=1280:720:0:0" output.mp4
# Crop center
ffmpeg -i input.mp4 -vf "crop=1920:800:0:140" output.mp4
# Crop to 16:9 from 4:3
ffmpeg -i input.mp4 -vf "crop=in_h*16/9:in_h" output.mp4
# Auto-detect crop
ffmpeg -i input.mp4 -vf "cropdetect" -f null -
# Then use detected values
ffmpeg -i input.mp4 -vf "crop=1920:800:0:140" output.mp4
# Crop and scale
ffmpeg -i input.mp4 -vf "crop=1920:800:0:140,scale=1280:534" output.mp4
Rotate and Flip
# Rotate 90 degrees clockwise
ffmpeg -i input.mp4 -vf "transpose=1" output.mp4
# Rotate 90 degrees counter-clockwise
ffmpeg -i input.mp4 -vf "transpose=2" output.mp4
# Rotate 180 degrees
ffmpeg -i input.mp4 -vf "transpose=2,transpose=2" output.mp4
# Flip horizontal
ffmpeg -i input.mp4 -vf "hflip" output.mp4
# Flip vertical
ffmpeg -i input.mp4 -vf "vflip" output.mp4
# Rotate by arbitrary angle
ffmpeg -i input.mp4 -vf "rotate=45*PI/180" output.mp4
Watermark
# Add image watermark
ffmpeg -i input.mp4 -i logo.png -filter_complex "overlay=10:10" output.mp4
# Watermark in bottom right
ffmpeg -i input.mp4 -i logo.png -filter_complex "overlay=W-w-10:H-h-10" output.mp4
# Watermark centered
ffmpeg -i input.mp4 -i logo.png -filter_complex "overlay=(W-w)/2:(H-h)/2" output.mp4
# Transparent watermark
ffmpeg -i input.mp4 -i logo.png -filter_complex "[1:v]format=rgba,colorchannelmixer=aa=0.5[logo];[0:v][logo]overlay=10:10" output.mp4
# Text watermark
ffmpeg -i input.mp4 -vf "drawtext=text='Copyright 2024':x=10:y=10:fontsize=24:fontcolor=white" output.mp4
# Text with shadow
ffmpeg -i input.mp4 -vf "drawtext=text='Copyright':x=10:y=10:fontsize=36:fontcolor=white:shadowcolor=black:shadowx=2:shadowy=2" output.mp4
# Dynamic timestamp
ffmpeg -i input.mp4 -vf "drawtext=text='%{localtime\:%Y-%m-%d %H\\:%M\\:%S}':x=10:y=10:fontsize=24:fontcolor=white" output.mp4
Fade In/Out
# Fade in video (first 60 frames, about 2 seconds at 30 fps)
ffmpeg -i input.mp4 -vf "fade=in:0:60" output.mp4
# Fade out video (last 2 seconds)
ffmpeg -i input.mp4 -vf "fade=out:st=28:d=2" output.mp4
# Fade in and out
ffmpeg -i input.mp4 -vf "fade=in:0:60,fade=out:st=28:d=2" output.mp4
# Audio fade in/out
ffmpeg -i input.mp4 -af "afade=in:st=0:d=2,afade=out:st=28:d=2" output.mp4
# Combined video and audio fade
ffmpeg -i input.mp4 -vf "fade=in:0:60,fade=out:st=28:d=2" -af "afade=in:st=0:d=2,afade=out:st=28:d=2" output.mp4
Color Adjustments
# Brightness
ffmpeg -i input.mp4 -vf "eq=brightness=0.1" output.mp4
# Contrast
ffmpeg -i input.mp4 -vf "eq=contrast=1.5" output.mp4
# Saturation
ffmpeg -i input.mp4 -vf "eq=saturation=1.5" output.mp4
# Gamma
ffmpeg -i input.mp4 -vf "eq=gamma=1.2" output.mp4
# Combined adjustments
ffmpeg -i input.mp4 -vf "eq=brightness=0.1:contrast=1.2:saturation=1.3" output.mp4
# Grayscale
ffmpeg -i input.mp4 -vf "hue=s=0" output.mp4
# Sepia tone
ffmpeg -i input.mp4 -vf "colorchannelmixer=.393:.769:.189:0:.349:.686:.168:0:.272:.534:.131" output.mp4
Blur and Sharpen
# Blur
ffmpeg -i input.mp4 -vf "boxblur=5:1" output.mp4
# Gaussian blur
ffmpeg -i input.mp4 -vf "gblur=sigma=5" output.mp4
# Sharpen
ffmpeg -i input.mp4 -vf "unsharp=5:5:1.5:5:5:0.0" output.mp4
# Denoise
ffmpeg -i input.mp4 -vf "nlmeans" output.mp4
Advanced Filters
Complex Filter Chains
# Scale and crop
ffmpeg -i input.mp4 -vf "scale=1920:1080,crop=1920:800:0:140" output.mp4
# Multiple filters
ffmpeg -i input.mp4 -vf "scale=1280:720,hue=s=1.5,eq=brightness=0.1" output.mp4
# Filter with audio
ffmpeg -i input.mp4 -vf "scale=1280:720" -af "volume=2.0" output.mp4
Picture-in-Picture
# Basic PIP
ffmpeg -i main.mp4 -i overlay.mp4 -filter_complex \
"[1:v]scale=320:240[pip];[0:v][pip]overlay=W-w-10:H-h-10" \
output.mp4
# PIP with different positions
# Top-left
ffmpeg -i main.mp4 -i overlay.mp4 -filter_complex \
"[1:v]scale=320:240[pip];[0:v][pip]overlay=10:10" output.mp4
# Top-right
ffmpeg -i main.mp4 -i overlay.mp4 -filter_complex \
"[1:v]scale=320:240[pip];[0:v][pip]overlay=W-w-10:10" output.mp4
# Bottom-left
ffmpeg -i main.mp4 -i overlay.mp4 -filter_complex \
"[1:v]scale=320:240[pip];[0:v][pip]overlay=10:H-h-10" output.mp4
Side-by-Side
# Side-by-side comparison
ffmpeg -i left.mp4 -i right.mp4 -filter_complex \
"[0:v][1:v]hstack=inputs=2" output.mp4
# Vertical stack
ffmpeg -i top.mp4 -i bottom.mp4 -filter_complex \
"[0:v][1:v]vstack=inputs=2" output.mp4
# 2x2 grid
ffmpeg -i input1.mp4 -i input2.mp4 -i input3.mp4 -i input4.mp4 \
-filter_complex \
"[0:v][1:v]hstack[top];[2:v][3:v]hstack[bottom];[top][bottom]vstack" \
output.mp4
Speed Changes
# Speed up video (2x)
ffmpeg -i input.mp4 -vf "setpts=0.5*PTS" output.mp4
# Slow down video (0.5x)
ffmpeg -i input.mp4 -vf "setpts=2.0*PTS" output.mp4
# Speed up audio
ffmpeg -i input.mp4 -filter:a "atempo=2.0" output.mp4
# Speed up both video and audio (2x)
ffmpeg -i input.mp4 -vf "setpts=0.5*PTS" -af "atempo=2.0" output.mp4
# Slow motion (0.5x) with audio
ffmpeg -i input.mp4 -vf "setpts=2.0*PTS" -af "atempo=0.5" output.mp4
# Speed limits: atempo must be between 0.5 and 2.0
# For 4x speed, chain multiple atempo filters
ffmpeg -i input.mp4 -filter:a "atempo=2.0,atempo=2.0" output.mp4
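Because each `atempo` instance is capped at 2.0, larger power-of-two speed-ups need a chain; a sketch with a hypothetical helper that builds the filter string:

```shell
# Build an atempo chain for a 2^n speed-up factor (4, 8, ...)
atempo_chain() {
  local speed=$1 chain=""
  while [ "$speed" -gt 1 ]; do
    chain="${chain:+$chain,}atempo=2.0"
    speed=$(( speed / 2 ))
  done
  echo "$chain"
}
atempo_chain 4   # prints: atempo=2.0,atempo=2.0
# Usage: ffmpeg -i input.mp4 -filter:a "$(atempo_chain 4)" output.mp4
```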
Framerate Changes
# Change framerate
ffmpeg -i input.mp4 -r 30 output.mp4 # 30 fps
ffmpeg -i input.mp4 -r 60 output.mp4 # 60 fps
# Convert to 24fps (film)
ffmpeg -i input.mp4 -r 24 output.mp4
# Duplicate frames to increase fps
ffmpeg -i input.mp4 -vf "fps=60" output.mp4
# Interpolate frames (smooth)
ffmpeg -i input.mp4 -vf "minterpolate=fps=60:mi_mode=mci" output.mp4
Screenshots and Thumbnails
Extract Single Frame
# Extract first frame
ffmpeg -i input.mp4 -vf "select=eq(n\,0)" -q:v 1 -frames:v 1 output.png
# Extract frame at specific time
ffmpeg -ss 00:00:10 -i input.mp4 -frames:v 1 output.jpg
# Extract frame at 5 seconds
ffmpeg -ss 5 -i input.mp4 -frames:v 1 output.png
# High quality screenshot
ffmpeg -ss 00:01:30 -i input.mp4 -frames:v 1 -q:v 2 output.jpg
# Specific size screenshot
ffmpeg -ss 10 -i input.mp4 -vf "scale=1920:1080" -frames:v 1 output.png
Extract Multiple Frames
# Extract every frame
ffmpeg -i input.mp4 frame_%04d.png
# Extract 1 frame per second
ffmpeg -i input.mp4 -vf "fps=1" frame_%04d.png
# Extract 1 frame every 10 seconds
ffmpeg -i input.mp4 -vf "fps=1/10" frame_%04d.png
# Extract frames from specific time range
ffmpeg -ss 00:00:10 -t 00:00:05 -i input.mp4 -vf "fps=1" frame_%04d.png
# Extract frames with specific quality
ffmpeg -i input.mp4 -vf "fps=1" -q:v 2 frame_%04d.jpg
Create Thumbnails
# Create thumbnail grid (contact sheet)
ffmpeg -i input.mp4 -vf "fps=1/60,scale=320:240,tile=4x3" thumbnail.png
# Create thumbnail at specific interval
ffmpeg -i input.mp4 -vf "thumbnail=300" -frames:v 1 thumb.png
# Create multiple thumbnails
ffmpeg -i input.mp4 -vf "fps=1/60" thumb_%03d.jpg
Create GIF
# Basic GIF
ffmpeg -i input.mp4 output.gif
# High quality GIF
ffmpeg -i input.mp4 -vf "fps=10,scale=320:-1:flags=lanczos,split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse" output.gif
# GIF from specific time range
ffmpeg -ss 5 -t 10 -i input.mp4 -vf "fps=10,scale=480:-1:flags=lanczos" output.gif
# Optimized GIF with custom palette
ffmpeg -i input.mp4 -vf "fps=15,scale=480:-1:flags=lanczos,palettegen" palette.png
ffmpeg -i input.mp4 -i palette.png -filter_complex "fps=15,scale=480:-1:flags=lanczos[x];[x][1:v]paletteuse" output.gif
Concatenation and Trimming
Trim/Cut Video
# Cut from start time for duration
ffmpeg -ss 00:00:10 -t 00:00:30 -i input.mp4 -c copy output.mp4
# Cut from start to end time
ffmpeg -ss 00:00:10 -to 00:00:40 -i input.mp4 -c copy output.mp4
# Cut with re-encoding (more precise)
ffmpeg -i input.mp4 -ss 00:00:10 -t 00:00:30 -c:v libx264 -c:a aac output.mp4
# Multiple segments
ffmpeg -i input.mp4 -ss 00:00:00 -t 00:00:10 part1.mp4
ffmpeg -i input.mp4 -ss 00:00:10 -t 00:00:10 part2.mp4
ffmpeg -i input.mp4 -ss 00:00:20 -t 00:00:10 part3.mp4
# Cut last N seconds
ffmpeg -sseof -10 -i input.mp4 -c copy last_10sec.mp4
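`-ss` and `-to` accept either plain seconds or `HH:MM:SS`; when generating cut points in a script, a small converter sketch like this (helper name is an assumption) keeps timestamps readable:

```shell
# Convert a seconds count into an HH:MM:SS timestamp for -ss/-to
to_ts() {
  printf '%02d:%02d:%02d\n' $(( $1 / 3600 )) $(( $1 % 3600 / 60 )) $(( $1 % 60 ))
}
to_ts 90     # prints: 00:01:30
to_ts 3725   # prints: 01:02:05
```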
Concatenate Videos
# Method 1: Concat demuxer (same codec, fast)
# Create file list
echo "file 'video1.mp4'" > filelist.txt
echo "file 'video2.mp4'" >> filelist.txt
echo "file 'video3.mp4'" >> filelist.txt
ffmpeg -f concat -safe 0 -i filelist.txt -c copy output.mp4
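Typing the file list by hand breaks when names contain single quotes: the concat demuxer expects embedded quotes escaped as `'\''`. A sketch that generates correctly escaped entries (file names here are assumptions):

```shell
# Emit concat-demuxer "file '...'" lines, escaping single quotes in names
make_filelist() {
  for f in "$@"; do
    printf "file '%s'\n" "$(printf '%s' "$f" | sed "s/'/'\\\\''/g")"
  done
}
make_filelist video1.mp4 "it's.mp4" > filelist.txt
# then: ffmpeg -f concat -safe 0 -i filelist.txt -c copy output.mp4
```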
# Method 2: Concat filter (different codecs)
ffmpeg -i video1.mp4 -i video2.mp4 -i video3.mp4 \
-filter_complex "[0:v][0:a][1:v][1:a][2:v][2:a]concat=n=3:v=1:a=1[outv][outa]" \
-map "[outv]" -map "[outa]" output.mp4
# Method 3: Concat protocol (identical files)
ffmpeg -i "concat:video1.mp4|video2.mp4|video3.mp4" -c copy output.mp4
# Concatenate with transition
ffmpeg -i input1.mp4 -i input2.mp4 -filter_complex \
"[0:v]fade=out:st=9:d=1[v0];[1:v]fade=in:st=0:d=1[v1];[v0][v1]concat=n=2:v=1:a=0" \
output.mp4
Split Video
# Split into equal parts
ffmpeg -i input.mp4 -c copy -map 0 -segment_time 300 -f segment output%03d.mp4
# Split at given frame numbers (the segment muxer cannot split by size;
# with -c copy, cuts land on the nearest keyframes)
ffmpeg -i input.mp4 -c copy -map 0 -segment_frames 1000,2000,3000 -f segment output%03d.mp4
# Split at keyframes
ffmpeg -i input.mp4 -c copy -segment_time 300 -reset_timestamps 1 -f segment output%03d.mp4
Streaming
HLS (HTTP Live Streaming)
# Basic HLS
ffmpeg -i input.mp4 -hls_time 10 -hls_list_size 0 -f hls output.m3u8
# HLS with different quality levels (adaptive streaming)
ffmpeg -i input.mp4 \
-vf "scale=1280:720" -c:v libx264 -b:v 2M -c:a aac -b:a 128k -hls_time 10 720p.m3u8 \
-vf "scale=854:480" -c:v libx264 -b:v 1M -c:a aac -b:a 96k -hls_time 10 480p.m3u8 \
-vf "scale=640:360" -c:v libx264 -b:v 500k -c:a aac -b:a 64k -hls_time 10 360p.m3u8
# HLS with segment naming
ffmpeg -i input.mp4 \
-hls_time 10 \
-hls_list_size 0 \
-hls_segment_filename "segment_%03d.ts" \
-f hls output.m3u8
# HLS with encryption
ffmpeg -i input.mp4 \
-hls_time 10 \
-hls_key_info_file key_info.txt \
-hls_list_size 0 \
-f hls output.m3u8
# HLS options explained:
#   -hls_time 6                 segment duration (seconds)
#   -hls_list_size 0            keep all segments in the playlist
#   -hls_segment_type mpegts    segment format
#   -hls_flags delete_segments  delete old segments
# (comments cannot follow a trailing backslash, so they are listed above)
ffmpeg -i input.mp4 \
-c:v libx264 -c:a aac \
-hls_time 6 \
-hls_list_size 0 \
-hls_segment_type mpegts \
-hls_flags delete_segments \
-hls_start_number_source datetime \
-f hls output.m3u8
DASH (Dynamic Adaptive Streaming over HTTP)
# Basic DASH
ffmpeg -i input.mp4 -c:v libx264 -c:a aac -f dash output.mpd
# DASH with multiple qualities
ffmpeg -i input.mp4 \
-map 0:v -map 0:a -c:v libx264 -c:a aac \
-b:v:0 2M -s:v:0 1280x720 \
-b:v:1 1M -s:v:1 854x480 \
-b:v:2 500k -s:v:2 640x360 \
-adaptation_sets "id=0,streams=v id=1,streams=a" \
-f dash output.mpd
RTMP Streaming
# Stream to RTMP server
ffmpeg -re -i input.mp4 -c:v libx264 -preset veryfast -maxrate 3M \
-bufsize 6M -c:a aac -b:a 128k -f flv rtmp://server/live/stream
# Stream with specific resolution and framerate
ffmpeg -re -i input.mp4 \
-vf "scale=1280:720" -r 30 \
-c:v libx264 -preset veryfast -b:v 2M \
-c:a aac -b:a 128k \
-f flv rtmp://server/live/stream
# Stream from webcam
ffmpeg -f v4l2 -i /dev/video0 -f alsa -i default \
-c:v libx264 -preset veryfast -b:v 1M \
-c:a aac -b:a 128k \
-f flv rtmp://server/live/stream
# Re-stream (relay)
ffmpeg -i rtmp://source/live/stream -c copy -f flv rtmp://destination/live/stream
UDP/RTP Streaming
# UDP streaming
ffmpeg -re -i input.mp4 -c:v libx264 -c:a aac -f mpegts udp://192.168.1.100:1234
# RTP streaming
ffmpeg -re -i input.mp4 -c:v libx264 -c:a aac -f rtp rtp://192.168.1.100:1234
# SRT streaming
ffmpeg -re -i input.mp4 -c:v libx264 -c:a aac -f mpegts srt://192.168.1.100:1234
Subtitles
Extract Subtitles
# Extract the first subtitle track (converted to SRT)
ffmpeg -i input.mkv -map 0:s:0 subtitles.srt
# Extract specific subtitle
ffmpeg -i input.mkv -map 0:s:0 -c:s copy subtitle_track1.srt
# Convert subtitle format
ffmpeg -i input.srt output.ass
ffmpeg -i input.ass output.srt
Add Subtitles
# Soft subtitles (embedded, can be toggled)
ffmpeg -i input.mp4 -i subtitles.srt -c copy -c:s mov_text output.mp4
# Add multiple subtitle tracks
ffmpeg -i input.mp4 -i eng.srt -i spa.srt \
-c copy -c:s mov_text \
-metadata:s:s:0 language=eng \
-metadata:s:s:1 language=spa \
output.mp4
# Hard subtitles (burned in, always visible)
ffmpeg -i input.mp4 -vf "subtitles=subtitles.srt" output.mp4
# Burn subtitles with style
ffmpeg -i input.mp4 -vf "subtitles=subtitles.srt:force_style='FontName=Arial,FontSize=24,PrimaryColour=&H00FFFF'" output.mp4
# Burn ASS/SSA subtitles
ffmpeg -i input.mp4 -vf "ass=subtitles.ass" output.mp4
Create Subtitles
# Generate subtitle from text file
# Create subtitle.srt:
# 1
# 00:00:00,000 --> 00:00:05,000
# First subtitle text
#
# 2
# 00:00:05,000 --> 00:00:10,000
# Second subtitle text
ffmpeg -i input.mp4 -i subtitle.srt -c copy -c:s mov_text output.mp4
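The commented-out cues above can be written to disk directly with a heredoc; a sketch (the mux step is guarded so it only runs where ffmpeg and an input file are actually present):

```shell
# Write a minimal two-cue SRT file
cat > subtitle.srt <<'EOF'
1
00:00:00,000 --> 00:00:05,000
First subtitle text

2
00:00:05,000 --> 00:00:10,000
Second subtitle text
EOF

# Mux it as a soft subtitle track (only if ffmpeg and input.mp4 are available here)
if command -v ffmpeg >/dev/null && [ -f input.mp4 ]; then
    ffmpeg -i input.mp4 -i subtitle.srt -c copy -c:s mov_text output.mp4
fi
```

Note the blank line between cues and the comma as the millisecond separator — both are part of the SRT format.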
Metadata
View Metadata
# Show all metadata
ffprobe -show_format -show_streams input.mp4
# Export metadata to a text file
ffmpeg -i input.mp4 -f ffmetadata metadata.txt
# Extract cover art
ffmpeg -i input.mp3 -an -vcodec copy cover.jpg
Edit Metadata
# Set metadata tags
ffmpeg -i input.mp4 -metadata title="My Video" \
-metadata author="John Doe" \
-metadata copyright="2024" \
-c copy output.mp4
# Remove all metadata
ffmpeg -i input.mp4 -map_metadata -1 -c copy output.mp4
# Add cover art to audio
ffmpeg -i input.mp3 -i cover.jpg \
-map 0:a -map 1:v \
-c:a copy -c:v copy \
-metadata:s:v title="Album cover" \
-metadata:s:v comment="Cover (front)" \
output.mp3
# Copy metadata from one file to another
ffmpeg -i source.mp4 -i destination.mp4 -map 1 -map_metadata 0 -c copy output.mp4
Performance and Hardware Acceleration
Hardware Encoding
# NVIDIA NVENC (H.264)
ffmpeg -i input.mp4 -c:v h264_nvenc -preset slow -b:v 2M output.mp4
# NVIDIA NVENC (H.265)
ffmpeg -i input.mp4 -c:v hevc_nvenc -preset slow -b:v 2M output.mp4
# Intel Quick Sync (H.264)
ffmpeg -i input.mp4 -c:v h264_qsv -preset slow -b:v 2M output.mp4
# Intel Quick Sync (H.265)
ffmpeg -i input.mp4 -c:v hevc_qsv -preset slow -b:v 2M output.mp4
# AMD VCE (H.264)
ffmpeg -i input.mp4 -c:v h264_amf -b:v 2M output.mp4
# Apple VideoToolbox (H.264)
ffmpeg -i input.mp4 -c:v h264_videotoolbox -b:v 2M output.mp4
# VA-API (Linux)
ffmpeg -vaapi_device /dev/dri/renderD128 -i input.mp4 \
-vf 'format=nv12,hwupload' -c:v h264_vaapi -b:v 2M output.mp4
Hardware Decoding
# NVIDIA CUDA decoding + NVENC encoding
ffmpeg -hwaccel cuda -i input.mp4 -c:v h264_nvenc -preset slow output.mp4
# Intel Quick Sync decoding + encoding
ffmpeg -hwaccel qsv -c:v h264_qsv -i input.mp4 -c:v h264_qsv output.mp4
# VA-API decoding + encoding
ffmpeg -hwaccel vaapi -hwaccel_device /dev/dri/renderD128 -i input.mp4 \
-vf 'format=nv12,hwupload' -c:v h264_vaapi output.mp4
Performance Options
# Multi-threading
ffmpeg -threads 4 -i input.mp4 output.mp4
ffmpeg -threads 0 -i input.mp4 output.mp4 # Auto detect
# Faster encoding (lower quality)
ffmpeg -i input.mp4 -preset ultrafast -crf 23 output.mp4
# Quality vs speed (presets)
# ultrafast, superfast, veryfast, faster, fast, medium, slow, slower, veryslow
ffmpeg -i input.mp4 -preset medium -crf 23 output.mp4
# Tune for specific content
ffmpeg -i input.mp4 -tune film output.mp4 # Film content
ffmpeg -i input.mp4 -tune animation output.mp4 # Animation
ffmpeg -i input.mp4 -tune grain output.mp4 # Grainy film
ffmpeg -i input.mp4 -tune stillimage output.mp4 # Slideshow
Common Patterns
Web-Optimized Video
# HTML5 video (MP4)
ffmpeg -i input.mp4 \
-c:v libx264 -preset slow -crf 22 \
-c:a aac -b:a 128k \
-movflags +faststart \
-vf "scale=1280:720" \
output.mp4
# WebM for web
ffmpeg -i input.mp4 \
-c:v libvpx-vp9 -crf 30 -b:v 0 \
-c:a libopus -b:a 128k \
-vf "scale=1280:720" \
output.webm
# Both formats for compatibility
ffmpeg -i input.mp4 -c:v libx264 -preset slow -crf 22 -movflags +faststart video.mp4
ffmpeg -i input.mp4 -c:v libvpx-vp9 -crf 30 -b:v 0 -c:a libopus video.webm
Social Media Formats
# Instagram (1:1 square)
ffmpeg -i input.mp4 \
-vf "scale=1080:1080:force_original_aspect_ratio=decrease,pad=1080:1080:(ow-iw)/2:(oh-ih)/2" \
-c:v libx264 -preset slow -crf 23 \
-c:a aac -b:a 128k \
instagram.mp4
# Instagram Stories (9:16)
ffmpeg -i input.mp4 \
-vf "scale=1080:1920:force_original_aspect_ratio=decrease,pad=1080:1920:(ow-iw)/2:(oh-ih)/2" \
-c:v libx264 -preset slow -crf 23 \
-c:a aac -b:a 128k \
story.mp4
# Twitter (16:9, < 512MB, < 2:20)
ffmpeg -i input.mp4 \
-c:v libx264 -preset slow -crf 23 -maxrate 2M -bufsize 4M \
-vf "scale=1280:720" \
-c:a aac -b:a 128k \
-movflags +faststart \
twitter.mp4
# YouTube (recommended settings)
ffmpeg -i input.mp4 \
-c:v libx264 -preset slow -crf 18 \
-c:a aac -b:a 192k \
-vf "scale=1920:1080" \
-r 30 \
-movflags +faststart \
youtube.mp4
Batch Processing
# Convert all MP4 files to WebM
for f in *.mp4; do
ffmpeg -i "$f" -c:v libvpx-vp9 -crf 30 "${f%.mp4}.webm"
done
# Batch resize
for f in *.mp4; do
ffmpeg -i "$f" -vf "scale=1280:720" "resized_${f}"
done
# Batch extract audio
for f in *.mp4; do
ffmpeg -i "$f" -vn -c:a libmp3lame -b:a 192k "${f%.mp4}.mp3"
done
# Parallel processing with GNU parallel
ls *.mp4 | parallel -j 4 ffmpeg -i {} -c:v libx264 -crf 23 {.}_converted.mp4
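The `${f%.mp4}` expansions used in the loops above strip the extension before a new one is appended; a quick illustration of that parameter expansion:

```shell
# ${f%.mp4} removes the shortest trailing match of ".mp4"
f="holiday.mp4"
echo "${f%.mp4}.webm"   # holiday.webm
echo "${f%.mp4}.mp3"    # holiday.mp3
```

If the suffix doesn't match, the variable is left unchanged, so the pattern is safe to apply across mixed file lists.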
Video from Images
# Create video from image sequence
ffmpeg -framerate 30 -pattern_type glob -i "*.jpg" -c:v libx264 -pix_fmt yuv420p output.mp4
# Specific pattern
ffmpeg -framerate 30 -i image_%04d.jpg -c:v libx264 output.mp4
# Slideshow with duration
ffmpeg -loop 1 -t 5 -i image.jpg -c:v libx264 -pix_fmt yuv420p output.mp4
# Slideshow from multiple images
ffmpeg -loop 1 -t 3 -i img1.jpg \
-loop 1 -t 3 -i img2.jpg \
-loop 1 -t 3 -i img3.jpg \
-filter_complex "[0:v][1:v][2:v]concat=n=3:v=1:a=0" \
slideshow.mp4
# Ken Burns effect (zoom and pan)
ffmpeg -loop 1 -i image.jpg \
-vf "zoompan=z='min(zoom+0.0015,1.5)':d=750:x='iw/2-(iw/zoom/2)':y='ih/2-(ih/zoom/2)':s=1920x1080" \
-c:v libx264 -t 30 output.mp4
Screen Recording Conversion
# Optimize screen recording
ffmpeg -i screen_recording.mp4 \
-c:v libx264 -preset slow -crf 18 \
-vf "scale=1920:1080" \
-c:a aac -b:a 128k \
optimized.mp4
# Remove silence from screen recording
ffmpeg -i recording.mp4 \
-af "silenceremove=1:0:-50dB" \
no_silence.mp4
Best Practices
1. Use Two-Pass Encoding for Best Quality
# Pass 1
ffmpeg -i input.mp4 -c:v libx264 -b:v 2M -pass 1 -f mp4 /dev/null
# Pass 2
ffmpeg -i input.mp4 -c:v libx264 -b:v 2M -pass 2 output.mp4
2. Use CRF for Variable Bitrate
# Better quality-to-size ratio
ffmpeg -i input.mp4 -c:v libx264 -crf 23 output.mp4
3. Fast Start for Web Videos
# Move moov atom to beginning (faster streaming start)
ffmpeg -i input.mp4 -c copy -movflags +faststart output.mp4
4. Preserve Quality with Stream Copy
# When changing container only, use -c copy
ffmpeg -i input.mkv -c copy output.mp4
5. Use Proper Pixel Format
# Ensure compatibility (yuv420p for most players)
ffmpeg -i input.mp4 -pix_fmt yuv420p output.mp4
6. Optimize Presets
# Balance quality and encoding time
ffmpeg -i input.mp4 -preset slow -crf 22 output.mp4
7. Check Input First
# Always analyze before processing
ffprobe -show_streams input.mp4
8. Use Appropriate Audio Bitrate
# Don't waste space on audio
ffmpeg -i input.mp4 -c:v libx264 -crf 23 -c:a aac -b:a 128k output.mp4
9. Batch Process Efficiently
# Use shell loops for multiple files
for f in *.mp4; do ffmpeg -i "$f" -c:v libx264 -crf 23 "${f%.mp4}_new.mp4"; done
10. Keep Original Aspect Ratio
# Use -1 to maintain aspect ratio (-2 also rounds to an even number, which most encoders require)
ffmpeg -i input.mp4 -vf "scale=1280:-2" output.mp4
Troubleshooting
Common Errors
# "Unknown encoder 'libx264'"
# Install ffmpeg with libx264 support
sudo apt install ffmpeg libx264-dev
# "Could not find codec parameters"
# File may be corrupted, try re-encoding
ffmpeg -err_detect ignore_err -i input.mp4 -c:v libx264 output.mp4
# "Invalid data found when processing input"
# Remux to an Annex B container (h264_mp4toannexb targets MPEG-TS, not MP4)
ffmpeg -i input.mp4 -c copy -bsf:v h264_mp4toannexb output.ts
# "Output file is empty"
# Check codecs and formats
ffprobe input.mp4
ffmpeg -i input.mp4 -c:v libx264 -c:a aac output.mp4
# "Encoder did not produce proper pts"
# Add -vsync vfr (renamed -fps_mode vfr in newer FFmpeg releases)
ffmpeg -i input.mp4 -vsync vfr output.mp4
Audio/Video Sync Issues
# Fix A/V sync
ffmpeg -i input.mp4 -async 1 -vsync 1 output.mp4
# Delay audio by 2 seconds
ffmpeg -i input.mp4 -itsoffset 2 -i input.mp4 -map 0:v -map 1:a -c copy output.mp4
# Advance audio by 2 seconds
ffmpeg -i input.mp4 -itsoffset -2 -i input.mp4 -map 0:v -map 1:a -c copy output.mp4
Quality Issues
# Improve quality (lower CRF)
ffmpeg -i input.mp4 -c:v libx264 -crf 18 output.mp4
# Two-pass for better quality
ffmpeg -i input.mp4 -c:v libx264 -b:v 2M -pass 1 -f mp4 /dev/null
ffmpeg -i input.mp4 -c:v libx264 -b:v 2M -pass 2 output.mp4
# Use better preset
ffmpeg -i input.mp4 -preset slower -crf 20 output.mp4
Performance Issues
# Use hardware acceleration
ffmpeg -hwaccel cuda -i input.mp4 -c:v h264_nvenc output.mp4
# Use faster preset
ffmpeg -i input.mp4 -preset ultrafast output.mp4
# Limit CPU usage
ffmpeg -threads 2 -i input.mp4 output.mp4
File Size Issues
# Reduce file size (increase CRF)
ffmpeg -i input.mp4 -c:v libx264 -crf 28 output.mp4
# Target specific file size (calculate bitrate)
# bitrate = (target_size_MB * 8192) / duration_seconds - audio_bitrate
ffmpeg -i input.mp4 -b:v 1000k -c:a aac -b:a 128k output.mp4
# Two-pass for exact size
ffmpeg -i input.mp4 -b:v 1000k -pass 1 -f mp4 /dev/null
ffmpeg -i input.mp4 -b:v 1000k -pass 2 output.mp4
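The bitrate formula above, worked through with shell arithmetic (the target size, duration, and audio bitrate are just example numbers):

```shell
# Target: a 50 MB file from a 300-second video with 128 kbit/s audio
TARGET_MB=50
DURATION=300
AUDIO_KBITS=128
# MB -> kbit: 50 * 8192 = 409600 kbit total budget; divide by duration, subtract audio
VIDEO_KBITS=$(( TARGET_MB * 8192 / DURATION - AUDIO_KBITS ))
echo "${VIDEO_KBITS}k"   # 1237k
```

Passing the result as `-b:v 1237k` lands near the target size; two-pass encoding (shown above) tracks it more precisely than a single pass.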
Quick Reference
Common Options
| Option | Description | Example |
|---|---|---|
| -i | Input file | -i input.mp4 |
| -c:v | Video codec | -c:v libx264 |
| -c:a | Audio codec | -c:a aac |
| -c copy | Copy streams | -c copy |
| -b:v | Video bitrate | -b:v 2M |
| -b:a | Audio bitrate | -b:a 128k |
| -crf | Quality (lower=better) | -crf 23 |
| -preset | Encoding speed | -preset slow |
| -vf | Video filter | -vf "scale=1280:720" |
| -af | Audio filter | -af "volume=2.0" |
| -ss | Start time | -ss 00:01:30 |
| -t | Duration | -t 00:00:10 |
| -to | End time | -to 00:02:00 |
| -r | Frame rate | -r 30 |
| -s | Resolution | -s 1920x1080 |
| -an | No audio | -an |
| -vn | No video | -vn |
| -sn | No subtitles | -sn |
| -map | Stream selection | -map 0:v:0 |
| -y | Overwrite output | -y |
| -n | Never overwrite | -n |
Codec Shortcuts
| Codec | Video | Audio |
|---|---|---|
| Copy | -c:v copy | -c:a copy |
| H.264 | -c:v libx264 | - |
| H.265 | -c:v libx265 | - |
| VP9 | -c:v libvpx-vp9 | - |
| AV1 | -c:v libaom-av1 | - |
| AAC | - | -c:a aac |
| MP3 | - | -c:a libmp3lame |
| Opus | - | -c:a libopus |
| Vorbis | - | -c:a libvorbis |
Quality Presets
| Preset | Speed | Quality |
|---|---|---|
| ultrafast | Fastest | Lowest |
| superfast | Very fast | Low |
| veryfast | Fast | Medium-low |
| faster | Moderate-fast | Medium |
| fast | Moderate | Good |
| medium | Moderate | Good (default) |
| slow | Slow | Very good |
| slower | Very slow | Excellent |
| veryslow | Slowest | Best |
CRF Values
| Codec | Range | Default | Recommended |
|---|---|---|---|
| H.264 | 0-51 | 23 | 18-28 |
| H.265 | 0-51 | 28 | 22-32 |
| VP9 | 0-63 | 30 | 15-35 |
| AV1 | 0-63 | 30 | 20-40 |
Useful Resources
- Official Documentation: https://ffmpeg.org/documentation.html
- Wiki: https://trac.ffmpeg.org/wiki
- Filters Documentation: https://ffmpeg.org/ffmpeg-filters.html
- Codecs: https://ffmpeg.org/ffmpeg-codecs.html
- Formats: https://ffmpeg.org/ffmpeg-formats.html
FFmpeg is an incredibly powerful tool with nearly limitless capabilities for audio and video processing. Master these patterns and you'll be able to handle virtually any multimedia task from the command line.
make
make is a build automation tool that automatically builds executable programs and libraries from source code by reading files called Makefiles which specify how to derive the target program.
Overview
make uses Makefiles to determine which parts of a program need to be recompiled and issues commands to rebuild them. It's particularly useful for managing dependencies in large projects.
Key Concepts:
- Target: The file to be created or action to be performed
- Prerequisites: Files that must exist before target can be built
- Recipe: Commands to create the target from prerequisites
- Rule: Combination of target, prerequisites, and recipe
- Phony Target: Target that doesn't represent a file
Basic Makefile
Simple Example
# Basic Makefile structure
target: prerequisites
recipe
# Example: Compile a C program
program: main.c
gcc -o program main.c
# Clean up build artifacts
clean:
rm -f program
Running make
# Build default target (first target in Makefile)
make
# Build specific target
make clean
# Build multiple targets
make program test
# Show commands without executing
make -n
# Run with specific Makefile
make -f MyMakefile
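A quick way to see make in action without an existing project; this sketch writes a throwaway Makefile with printf (whose `\t` guarantees the recipe starts with a real TAB, as make requires) and previews it with `make -n`, guarded in case make is not installed:

```shell
# Recipe lines must begin with a TAB character; printf's \t produces one
printf 'hello:\n\techo "Hello from make"\n' > Makefile
if command -v make >/dev/null; then
    make -n hello   # show the recipe without running it
fi
```

Replacing the TAB with spaces in the generated file reproduces the classic "missing separator" error discussed later.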
Makefile Syntax
Basic Structure
# Comments start with #
# Variable definition
CC = gcc
CFLAGS = -Wall -O2
# Rule with target, prerequisites, and recipe
program: main.o utils.o
$(CC) -o program main.o utils.o
# Multiple recipes (each on new line, indented with TAB)
main.o: main.c
@echo "Compiling main.c"
$(CC) $(CFLAGS) -c main.c
# Target with no prerequisites
clean:
rm -f *.o program
Important: Recipes must be indented with a TAB character, not spaces.
Variables
# Simple variable assignment
CC = gcc
CXX = g++
CFLAGS = -Wall -Wextra -O2
# Recursive expansion (evaluated when used)
SRCS = $(wildcard *.c)
OBJS = $(SRCS:.c=.o)
# Simple expansion (evaluated immediately)
NOW := $(shell date)
# Conditional assignment (only if not set)
CC ?= gcc
# Append to variable
CFLAGS += -g
# Using variables
program: main.c
$(CC) $(CFLAGS) -o program main.c
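The difference between `=` (recursive, expanded when used) and `:=` (expanded immediately) is easiest to see with a throwaway makefile; a sketch, guarded in case make is absent (`demo.mk` and the variable names are invented for this demo):

```shell
# LAZY uses =, so $(WHO) is looked up when the recipe runs -> "world"
# EAGER uses :=, so $(WHO) was still empty at assignment time -> ""
printf 'LAZY = $(WHO)\nEAGER := $(WHO)\nWHO = world\nshow:\n\t@echo "LAZY=$(LAZY) EAGER=$(EAGER)"\n' > demo.mk
if command -v make >/dev/null; then
    make -f demo.mk show
fi
```

This is why `:=` is preferred for `$(shell ...)` results: the command runs once at parse time instead of on every reference.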
Automatic Variables
# $@ - Target name
# $< - First prerequisite
# $^ - All prerequisites
# $? - Prerequisites newer than target
# $* - Stem of pattern rule match
# Example usage
%.o: %.c
$(CC) $(CFLAGS) -c $< -o $@
# $< is the .c file
# $@ is the .o file
program: main.o utils.o
$(CC) -o $@ $^
# $@ is 'program'
# $^ is 'main.o utils.o'
Pattern Rules
Suffix Rules
# Pattern rule for .c -> .o
%.o: %.c
$(CC) $(CFLAGS) -c $< -o $@
# Pattern rule for .cpp -> .o
%.o: %.cpp
$(CXX) $(CXXFLAGS) -c $< -o $@
# Multiple wildcards
bin/%: src/%.c
$(CC) $(CFLAGS) $< -o $@
Wildcards
# Wildcard function
SRCS = $(wildcard src/*.c)
OBJS = $(wildcard obj/*.o)
# Pattern substitution
OBJS = $(SRCS:.c=.o)
OBJS = $(SRCS:%.c=%.o)
OBJS = $(patsubst %.c,%.o,$(SRCS))
# Example
SOURCES = $(wildcard *.c)
OBJECTS = $(SOURCES:.c=.o)
DEPS = $(SOURCES:.c=.d)
Phony Targets
# Declare phony targets
.PHONY: all clean install test
# Common phony targets
all: program library
clean:
rm -f *.o *.d program
install: program
cp program /usr/local/bin/
test: program
./run_tests.sh
# Prevent make from checking if 'clean' file exists
.PHONY: clean
clean:
rm -f *.o program
C/C++ Project Examples
Simple C Project
# Compiler and flags
CC = gcc
CFLAGS = -Wall -Wextra -O2 -g
# Target executable
TARGET = myprogram
# Source files
SRCS = main.c utils.c parser.c
OBJS = $(SRCS:.c=.o)
# Default target
all: $(TARGET)
# Link object files
$(TARGET): $(OBJS)
$(CC) -o $@ $^
# Compile source files
%.o: %.c
$(CC) $(CFLAGS) -c $< -o $@
# Clean build artifacts
clean:
rm -f $(OBJS) $(TARGET)
# Phony targets
.PHONY: all clean
C Project with Headers
CC = gcc
CFLAGS = -Wall -Wextra -O2 -Iinclude
SRCDIR = src
OBJDIR = obj
BINDIR = bin
SRCS = $(wildcard $(SRCDIR)/*.c)
OBJS = $(SRCS:$(SRCDIR)/%.c=$(OBJDIR)/%.o)
TARGET = $(BINDIR)/program
all: $(TARGET)
$(TARGET): $(OBJS) | $(BINDIR)
$(CC) -o $@ $^
$(OBJDIR)/%.o: $(SRCDIR)/%.c | $(OBJDIR)
$(CC) $(CFLAGS) -c $< -o $@
# Create directories if they don't exist
$(BINDIR) $(OBJDIR):
mkdir -p $@
clean:
rm -rf $(OBJDIR) $(BINDIR)
.PHONY: all clean
C++ Project with Libraries
CXX = g++
CXXFLAGS = -std=c++17 -Wall -Wextra -O2
LDFLAGS = -lpthread -lm
SRCDIR = src
OBJDIR = obj
BINDIR = bin
INCDIR = include
SRCS = $(wildcard $(SRCDIR)/*.cpp)
OBJS = $(SRCS:$(SRCDIR)/%.cpp=$(OBJDIR)/%.o)
DEPS = $(OBJS:.o=.d)
TARGET = $(BINDIR)/program
all: $(TARGET)
$(TARGET): $(OBJS) | $(BINDIR)
$(CXX) $(CXXFLAGS) -o $@ $^ $(LDFLAGS)
$(OBJDIR)/%.o: $(SRCDIR)/%.cpp | $(OBJDIR)
$(CXX) $(CXXFLAGS) -I$(INCDIR) -MMD -MP -c $< -o $@
$(BINDIR) $(OBJDIR):
mkdir -p $@
clean:
rm -rf $(OBJDIR) $(BINDIR)
# Include dependency files
-include $(DEPS)
.PHONY: all clean
Multi-target Project
CC = gcc
CFLAGS = -Wall -Wextra -O2
# Multiple programs
PROGRAMS = server client
all: $(PROGRAMS)
server: server.o network.o utils.o
$(CC) -o $@ $^
client: client.o network.o
$(CC) -o $@ $^
%.o: %.c
$(CC) $(CFLAGS) -c $< -o $@
clean:
rm -f *.o $(PROGRAMS)
.PHONY: all clean
Advanced Features
Conditional Statements
# Check variable value
ifdef DEBUG
CFLAGS += -g -DDEBUG
else
CFLAGS += -O2
endif
# Conditional based on value
ifeq ($(CC),gcc)
CFLAGS += -Wall
endif
ifneq ($(OS),Windows_NT)
LDFLAGS += -lpthread
endif
# OS detection
UNAME := $(shell uname -s)
ifeq ($(UNAME),Linux)
LDFLAGS += -lrt
endif
ifeq ($(UNAME),Darwin)
LDFLAGS += -framework CoreFoundation
endif
Functions
# Substitution
SRCS = main.c utils.c parser.c
OBJS = $(SRCS:.c=.o)
OBJS = $(patsubst %.c,%.o,$(SRCS))
# Directory operations
DIRS = $(dir src/main.c include/utils.h) # "src/ include/"
FILES = $(notdir src/main.c include/utils.h) # "main.c utils.h"
# String manipulation
FILES = $(wildcard *.c)
NAMES = $(basename $(FILES)) # Remove extension
UPPERS = $(shell echo $(FILES) | tr a-z A-Z)
# Filtering
SRCS = main.c test.c utils.c
PROD_SRCS = $(filter-out test.c,$(SRCS)) # "main.c utils.c"
TEST_SRCS = $(filter test%,$(SRCS)) # "test.c"
# Shell commands
DATE := $(shell date +%Y%m%d)
GIT_HASH := $(shell git rev-parse --short HEAD)
Include Directives
# Include another makefile
include config.mk
# Include with error if missing
include required.mk
# Include without error if missing
-include optional.mk
# Include all dependency files
-include $(DEPS)
# Example: config.mk
# CC = gcc
# CFLAGS = -Wall -O2
Recursive Make
# Top-level Makefile
SUBDIRS = lib src tests
all:
for dir in $(SUBDIRS); do \
$(MAKE) -C $$dir; \
done
clean:
for dir in $(SUBDIRS); do \
$(MAKE) -C $$dir clean; \
done
.PHONY: all clean
Dependency Generation
CC = gcc
CFLAGS = -Wall -O2
SRCS = main.c utils.c
OBJS = $(SRCS:.c=.o)
DEPS = $(SRCS:.c=.d)
program: $(OBJS)
$(CC) -o $@ $^
# Generate dependencies automatically
%.o: %.c
$(CC) $(CFLAGS) -MMD -MP -c $< -o $@
# Include generated dependency files
-include $(DEPS)
clean:
rm -f $(OBJS) $(DEPS) program
.PHONY: clean
Common Patterns
Debug and Release Builds
CC = gcc
CFLAGS = -Wall -Wextra
# Build modes
ifdef DEBUG
CFLAGS += -g -O0 -DDEBUG
TARGET = program_debug
else
CFLAGS += -O2 -DNDEBUG
TARGET = program
endif
SRCS = main.c utils.c
OBJS = $(SRCS:.c=.o)
all: $(TARGET)
$(TARGET): $(OBJS)
$(CC) -o $@ $^
%.o: %.c
$(CC) $(CFLAGS) -c $< -o $@
clean:
rm -f $(OBJS) program program_debug
# Usage: make DEBUG=1
.PHONY: all clean
Installation Targets
PREFIX = /usr/local
BINDIR = $(PREFIX)/bin
DATADIR = $(PREFIX)/share/myapp
all: program
program: main.o
$(CC) -o $@ $^
install: program
install -d $(BINDIR)
install -m 755 program $(BINDIR)
install -d $(DATADIR)
install -m 644 data/* $(DATADIR)
uninstall:
rm -f $(BINDIR)/program
rm -rf $(DATADIR)
.PHONY: all install uninstall
Test Targets
CC = gcc
CFLAGS = -Wall -Wextra -O2
SRCS = main.c utils.c
TEST_SRCS = test_utils.c
OBJS = $(SRCS:.c=.o)
TEST_OBJS = $(TEST_SRCS:.c=.o)
program: $(OBJS)
$(CC) -o $@ $^
test_runner: $(TEST_OBJS) utils.o
$(CC) -o $@ $^
test: test_runner
./test_runner
%.o: %.c
$(CC) $(CFLAGS) -c $< -o $@
clean:
rm -f $(OBJS) $(TEST_OBJS) program test_runner
.PHONY: test clean
Static Library
CC = gcc
AR = ar
CFLAGS = -Wall -Wextra -O2
LIBNAME = mylib
SRCS = lib1.c lib2.c lib3.c
OBJS = $(SRCS:.c=.o)
TARGET = lib$(LIBNAME).a
all: $(TARGET)
$(TARGET): $(OBJS)
$(AR) rcs $@ $^
%.o: %.c
$(CC) $(CFLAGS) -c $< -o $@
install: $(TARGET)
install -d /usr/local/lib
install -m 644 $(TARGET) /usr/local/lib
install -d /usr/local/include/$(LIBNAME)
install -m 644 *.h /usr/local/include/$(LIBNAME)
clean:
rm -f $(OBJS) $(TARGET)
.PHONY: all install clean
Shared Library
CC = gcc
CFLAGS = -Wall -Wextra -O2 -fPIC
LDFLAGS = -shared
LIBNAME = mylib
VERSION = 1.0.0
MAJOR = 1
SRCS = lib1.c lib2.c lib3.c
OBJS = $(SRCS:.c=.o)
TARGET = lib$(LIBNAME).so.$(VERSION)
SONAME = lib$(LIBNAME).so.$(MAJOR)
LINKNAME = lib$(LIBNAME).so
all: $(TARGET)
$(TARGET): $(OBJS)
$(CC) $(LDFLAGS) -Wl,-soname,$(SONAME) -o $@ $^
%.o: %.c
$(CC) $(CFLAGS) -c $< -o $@
install: $(TARGET)
install -d /usr/local/lib
install -m 755 $(TARGET) /usr/local/lib
ln -sf $(TARGET) /usr/local/lib/$(SONAME)
ln -sf $(SONAME) /usr/local/lib/$(LINKNAME)
ldconfig
clean:
rm -f $(OBJS) $(TARGET)
.PHONY: all install clean
Make Options
Common Flags
# Run in parallel (4 jobs)
make -j4
# Keep going on errors
make -k
# Show commands without executing
make -n
make --dry-run
# Print directory changes
make -w
# Ignore errors
make -i
# Touch files instead of building
make -t
# Print database of rules
make -p
# Warn when undefined variables are referenced
make --warn-undefined-variables
Environment Variables
# Override variables
make CC=clang CFLAGS="-O3"
# Use specific Makefile
make -f Makefile.custom
# Change directory
make -C src/
# Set variables in Makefile
export CC=gcc
make
Best Practices
Structure and Organization
# 1. Use variables for configurability
CC = gcc
CFLAGS = -Wall -Wextra -O2
PREFIX = /usr/local
# 2. Declare phony targets
.PHONY: all clean install test
# 3. Use automatic variables
%.o: %.c
$(CC) $(CFLAGS) -c $< -o $@
# 4. Add help target
help:
@echo "Available targets:"
@echo " all - Build the program"
@echo " clean - Remove build artifacts"
@echo " install - Install the program"
@echo " test - Run tests"
# 5. Use default goal
.DEFAULT_GOAL := all
Dependency Management
# Auto-generate dependencies
CC = gcc
CFLAGS = -Wall -O2
DEPFLAGS = -MMD -MP
SRCS = $(wildcard *.c)
OBJS = $(SRCS:.c=.o)
DEPS = $(SRCS:.c=.d)
%.o: %.c
$(CC) $(CFLAGS) $(DEPFLAGS) -c $< -o $@
-include $(DEPS)
clean:
rm -f $(OBJS) $(DEPS)
Error Handling
# Request POSIX-conforming behavior (make already stops on the first error by default)
.POSIX:
# Check for required tools
CHECK_CC := $(shell command -v $(CC) 2> /dev/null)
ifndef CHECK_CC
$(error $(CC) not found in PATH)
endif
# Validate variables
ifndef TARGET
$(error TARGET is not defined)
endif
# Conditional compilation
program: main.o
ifeq ($(CC),)
$(error CC is not set)
endif
$(CC) -o $@ $^
Silent and Verbose Modes
# Silent mode (suppress echo of commands)
.SILENT:
# Selective silence
all:
@echo "Building..."
$(CC) -o program main.c
# Verbose mode controlled by variable
ifdef VERBOSE
Q =
else
Q = @
endif
%.o: %.c
@echo "CC $<"
$(Q)$(CC) $(CFLAGS) -c $< -o $@
Troubleshooting
Common Issues
# "Missing separator" error
# Problem: Using spaces instead of TAB in recipe
# Solution: Ensure recipes are indented with TAB
# "No rule to make target" error
# Problem: Make can't find prerequisite file
make --debug=v # Verbose debug output
# "Circular dependency" error
# Problem: Target depends on itself
# Solution: Review dependency chain
# Rebuild everything
make clean && make
# Show what make would do
make -n
# Print a variable's value (requires a print-% rule; see Debug Makefile below)
make print-VARIABLE
Debug Makefile
# Print variable values
print-%:
@echo $* = $($*)
# Usage: make print-CFLAGS
# Debug output
$(info Building with CC=$(CC))
$(warning This is a warning message)
$(error This stops the build)
# Show all variables
debug:
@echo "SRCS = $(SRCS)"
@echo "OBJS = $(OBJS)"
@echo "CFLAGS = $(CFLAGS)"
Performance Optimization
# Parallel builds
make -j$(nproc) # Use all CPU cores
# Profile make execution
make -d > debug.log 2>&1
# Check which targets are rebuilt
make -d | grep "Must remake"
# Use ccache for faster compilation
CC = ccache gcc
Complete Example
# Project configuration
PROJECT = myapp
VERSION = 1.0.0
# Compiler settings
CC = gcc
CXX = g++
CFLAGS = -Wall -Wextra -std=c11 -O2
CXXFLAGS = -Wall -Wextra -std=c++17 -O2
LDFLAGS = -lm -lpthread
# Directories
SRCDIR = src
INCDIR = include
OBJDIR = obj
BINDIR = bin
TESTDIR = tests
# Files
SRCS = $(wildcard $(SRCDIR)/*.c)
OBJS = $(SRCS:$(SRCDIR)/%.c=$(OBJDIR)/%.o)
DEPS = $(OBJS:.o=.d)
TARGET = $(BINDIR)/$(PROJECT)
# Installation paths
PREFIX = /usr/local
BINPREFIX = $(PREFIX)/bin
# Build modes
ifdef DEBUG
CFLAGS += -g -DDEBUG
CXXFLAGS += -g -DDEBUG
endif
ifdef VERBOSE
Q =
else
Q = @
endif
# Targets
.PHONY: all clean install uninstall test help
all: $(TARGET)
$(TARGET): $(OBJS) | $(BINDIR)
@echo "Linking $@"
$(Q)$(CC) -o $@ $^ $(LDFLAGS)
$(OBJDIR)/%.o: $(SRCDIR)/%.c | $(OBJDIR)
@echo "Compiling $<"
$(Q)$(CC) $(CFLAGS) -I$(INCDIR) -MMD -MP -c $< -o $@
$(BINDIR) $(OBJDIR):
$(Q)mkdir -p $@
clean:
@echo "Cleaning build artifacts"
$(Q)rm -rf $(OBJDIR) $(BINDIR)
install: $(TARGET)
@echo "Installing to $(BINPREFIX)"
$(Q)install -d $(BINPREFIX)
$(Q)install -m 755 $(TARGET) $(BINPREFIX)
uninstall:
@echo "Uninstalling from $(BINPREFIX)"
$(Q)rm -f $(BINPREFIX)/$(PROJECT)
test: $(TARGET)
@echo "Running tests"
$(Q)./$(TESTDIR)/run_tests.sh
help:
@echo "Available targets:"
@echo " all - Build the project (default)"
@echo " clean - Remove build artifacts"
@echo " install - Install the program"
@echo " uninstall - Uninstall the program"
@echo " test - Run tests"
@echo " help - Show this help message"
@echo ""
@echo "Build modes:"
@echo " make DEBUG=1 - Build with debug symbols"
@echo " make VERBOSE=1 - Show full commands"
-include $(DEPS)
Useful Tips
- Always use .PHONY for non-file targets
- Use automatic variables ($@, $<, $^) for maintainability
- Generate dependencies automatically with -MMD -MP
- Support parallel builds with make -j
- Use variables for all configuration options
- Include a help target for user guidance
- Handle errors gracefully with proper checks
- Keep Makefiles readable with comments and organization
make simplifies building complex projects by managing dependencies and minimizing rebuild time, making it an essential tool for C/C++ development and beyond.
Docker
Docker is a platform for developing, shipping, and running applications in containers. Containers package software with all dependencies, ensuring consistent behavior across different environments.
Overview
Docker enables developers to package applications with their dependencies into standardized units called containers, which can run anywhere Docker is installed.
Key Concepts:
- Container: Lightweight, standalone executable package
- Image: Read-only template for creating containers
- Dockerfile: Script defining how to build an image
- Registry: Repository for storing and distributing images
- Docker Hub: Public registry for Docker images
Installation
# Ubuntu/Debian
curl -fsSL https://get.docker.com -o get-docker.sh
sudo sh get-docker.sh
sudo usermod -aG docker $USER  # log out and back in for the group change to take effect
# Verify installation
docker --version
docker run hello-world
Basic Commands
Container Operations
# Run a container
docker run nginx
docker run -d nginx # Detached mode
docker run -it ubuntu bash # Interactive terminal
# Run with options
docker run -d \
--name my-nginx \
-p 8080:80 \
-v /host/path:/container/path \
-e ENV_VAR=value \
nginx
# List containers
docker ps # Running containers
docker ps -a # All containers
# Stop/Start containers
docker stop container_name
docker start container_name
docker restart container_name
# Remove containers
docker rm container_name
docker rm -f container_name # Force remove
docker container prune # Remove all stopped containers
Image Operations
# List images
docker images
docker image ls
# Pull image from registry
docker pull nginx
docker pull nginx:1.21
# Build image from Dockerfile
docker build -t myapp:1.0 .
docker build -t myapp:latest -f Dockerfile.prod .
# Remove images
docker rmi image_name
docker image prune # Remove unused images
docker image prune -a # Remove all unused images
# Tag image
docker tag myapp:1.0 username/myapp:1.0
# Push to registry
docker push username/myapp:1.0
Logs and Debugging
# View logs
docker logs container_name
docker logs -f container_name # Follow logs
docker logs --tail 100 container_name
# Execute command in container
docker exec container_name ls /app
docker exec -it container_name bash
# Inspect container
docker inspect container_name
docker stats container_name # Resource usage
# Copy files
docker cp file.txt container_name:/path/
docker cp container_name:/path/file.txt ./
Dockerfile
Basic Dockerfile
# Base image
FROM node:18-alpine
# Set working directory
WORKDIR /app
# Copy package files
COPY package*.json ./
# Install dependencies
RUN npm install
# Copy application code
COPY . .
# Expose port
EXPOSE 3000
# Set environment variables
ENV NODE_ENV=production
# Run command
CMD ["node", "server.js"]
Multi-stage Build
# Build stage
FROM node:18-alpine AS builder
WORKDIR /app
COPY package*.json ./
RUN npm install
COPY . .
RUN npm run build
# Production stage
FROM node:18-alpine
WORKDIR /app
COPY --from=builder /app/dist ./dist
COPY package*.json ./
RUN npm install --production
EXPOSE 3000
CMD ["node", "dist/server.js"]
Dockerfile Instructions
# FROM: Base image
FROM ubuntu:22.04
# LABEL: Metadata
LABEL maintainer="dev@example.com"
LABEL version="1.0"
# ENV: Environment variables
ENV APP_HOME=/app
ENV PORT=8080
# ARG: Build-time variables
ARG VERSION=latest
RUN echo "Building version ${VERSION}"
# WORKDIR: Set working directory
WORKDIR /app
# COPY: Copy files from host
COPY src/ /app/src/
# ADD: Copy and extract archives
ADD archive.tar.gz /app/
# RUN: Execute commands during build
RUN apt-get update && \
apt-get install -y python3 && \
rm -rf /var/lib/apt/lists/*
# USER: Set user
USER appuser
# EXPOSE: Document ports
EXPOSE 8080 8443
# VOLUME: Create mount point
VOLUME ["/data"]
# ENTRYPOINT: Configure container executable
ENTRYPOINT ["python3"]
# CMD: Default arguments for ENTRYPOINT
CMD ["app.py"]
# HEALTHCHECK: Container health check
HEALTHCHECK --interval=30s --timeout=3s \
CMD curl -f http://localhost/ || exit 1
Docker Compose
Basic docker-compose.yml
version: '3.8'
services:
web:
build: .
ports:
- "8080:80"
volumes:
- ./src:/app/src
environment:
- NODE_ENV=development
depends_on:
- db
db:
image: postgres:15
volumes:
- db-data:/var/lib/postgresql/data
environment:
POSTGRES_PASSWORD: secret
POSTGRES_DB: myapp
volumes:
db-data:
Docker Compose Commands
# Start services (Compose v2 invokes these as "docker compose ..."; the legacy "docker-compose" binary also works)
docker-compose up
docker-compose up -d # Detached
# Stop services
docker-compose down
docker-compose down -v # Remove volumes
# Build services
docker-compose build
docker-compose build --no-cache
# View logs
docker-compose logs
docker-compose logs -f service_name
# Execute commands
docker-compose exec web bash
docker-compose exec db psql -U postgres
# Scale services
docker-compose up -d --scale web=3
# List services
docker-compose ps
Advanced Compose Configuration
version: '3.8'
services:
app:
build:
context: .
dockerfile: Dockerfile.dev
args:
VERSION: "1.0"
image: myapp:latest
container_name: myapp
restart: unless-stopped
ports:
- "3000:3000"
volumes:
- ./src:/app/src:ro # Read-only
- node_modules:/app/node_modules
environment:
NODE_ENV: development
DATABASE_URL: postgres://db:5432/myapp
env_file:
- .env
depends_on:
db:
condition: service_healthy
networks:
- backend
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:3000/health"]
interval: 30s
timeout: 10s
retries: 3
db:
image: postgres:15-alpine
volumes:
- postgres_data:/var/lib/postgresql/data
- ./init.sql:/docker-entrypoint-initdb.d/init.sql
environment:
POSTGRES_PASSWORD: ${DB_PASSWORD}
networks:
- backend
networks:
backend:
driver: bridge
volumes:
postgres_data:
node_modules:
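The healthcheck above probes every 30 s and only marks the container unhealthy after `retries` consecutive failures; any successful probe resets the count. A hypothetical Python sketch of that decision rule (not Docker's actual implementation):

```python
def health_status(probe_results, retries=3):
    """Return 'healthy' or 'unhealthy' from a sequence of probe outcomes
    (True = probe succeeded), mirroring the consecutive-failure rule."""
    consecutive_failures = 0
    status = "healthy"
    for ok in probe_results:
        if ok:
            consecutive_failures = 0   # one success resets the counter
            status = "healthy"
        else:
            consecutive_failures += 1
            if consecutive_failures >= retries:
                status = "unhealthy"
    return status

print(health_status([True, False, False, True]))   # never hits 3 failures
print(health_status([True, False, False, False]))  # 3 straight failures
```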
Networking
Network Commands
# List networks
docker network ls
# Create network
docker network create mynetwork
docker network create --driver bridge mynetwork
# Connect container to network
docker network connect mynetwork container_name
# Disconnect from network
docker network disconnect mynetwork container_name
# Inspect network
docker network inspect mynetwork
# Remove network
docker network rm mynetwork
Network Types
# Bridge (default)
docker run --network bridge nginx
# Host (use host's network)
docker run --network host nginx
# None (no networking)
docker run --network none nginx
# Custom bridge network (containers on it can reach each other by name)
docker network create app-network
docker run --network app-network --name web nginx
docker run --network app-network --name db postgres
Volumes
Volume Management
# Create volume
docker volume create myvolume
# List volumes
docker volume ls
# Inspect volume
docker volume inspect myvolume
# Remove volume
docker volume rm myvolume
docker volume prune # Remove unused volumes
# Use volume in container
docker run -v myvolume:/data nginx
docker run --mount source=myvolume,target=/data nginx
Volume Types
# Named volume
docker run -v myvolume:/app/data nginx
# Bind mount (host directory)
docker run -v /host/path:/container/path nginx
docker run -v $(pwd):/app nginx
# Anonymous volume
docker run -v /container/path nginx
# Read-only volume
docker run -v myvolume:/data:ro nginx
Best Practices
Dockerfile Optimization
# 1. Use specific image tags
FROM node:18.16-alpine # Good
FROM node:latest # Avoid
# 2. Minimize layers
RUN apt-get update && apt-get install -y \
package1 \
package2 \
&& rm -rf /var/lib/apt/lists/*
# 3. Order instructions by frequency of change
FROM node:18-alpine
WORKDIR /app
COPY package*.json ./ # Changes less frequently
RUN npm install
COPY . . # Changes more frequently
# 4. Use .dockerignore
# Create .dockerignore file:
# node_modules
# .git
# .env
# *.log
# 5. Don't run as root
RUN addgroup -g 1001 appgroup && \
adduser -D -u 1001 -G appgroup appuser
USER appuser
# 6. Use multi-stage builds
FROM node:18 AS builder
WORKDIR /app
COPY . .
RUN npm run build
FROM node:18-alpine
COPY --from=builder /app/dist /app/dist
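Point 3 works because the build cache is invalidated from the first changed instruction onward: everything after it is rebuilt. A small sketch of that rule (hypothetical, not BuildKit's real cache logic):

```python
def layers_to_rebuild(cached_instructions, new_instructions):
    """Return the instructions that must be rebuilt: everything from the
    first instruction that differs from the previously cached build."""
    for i, instr in enumerate(new_instructions):
        if i >= len(cached_instructions) or cached_instructions[i] != instr:
            return new_instructions[i:]
    return []  # prefix fully cached

cached = ["FROM node:18-alpine", "COPY package*.json ./", "RUN npm install", "COPY . ."]
new    = ["FROM node:18-alpine", "COPY package*.json ./", "RUN npm install", "COPY . ."]
print(layers_to_rebuild(cached, new))  # []
```

This is why copying `package*.json` before the application source pays off: editing app code only invalidates the final `COPY . .` layer, so `npm install` stays cached.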
Security Best Practices
# 1. Scan images for vulnerabilities
docker scan myimage:latest          # legacy; newer Docker uses: docker scout cves myimage:latest
# 2. Use official images
docker pull nginx:alpine
# 3. Keep base images updated (re-pull your pinned tag regularly)
docker pull nginx:alpine
# 4. Limit container resources
docker run --memory="512m" --cpus="1.0" nginx
# 5. Run as non-root user
docker run --user 1000:1000 nginx
# 6. Use secrets for sensitive data
docker secret create db_password password.txt
docker service create --secret db_password myapp
Common Patterns
Development Environment
# docker-compose.dev.yml
version: '3.8'
services:
app:
build:
context: .
target: development
volumes:
- .:/app
- /app/node_modules
ports:
- "3000:3000"
environment:
NODE_ENV: development
command: npm run dev
Production Setup
# docker-compose.prod.yml
version: '3.8'
services:
app:
image: myapp:${VERSION:-latest}
restart: always
ports:
- "80:3000"
environment:
NODE_ENV: production
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:3000/health"]
interval: 30s
deploy:
replicas: 3
resources:
limits:
cpus: '1'
memory: 512M
Backup Script
#!/bin/bash
# Backup Docker volume
VOLUME_NAME="mydata"
BACKUP_FILE="backup-$(date +%Y%m%d-%H%M%S).tar.gz"
docker run --rm \
-v ${VOLUME_NAME}:/data \
-v $(pwd):/backup \
alpine \
tar czf /backup/${BACKUP_FILE} -C /data .
echo "Backup created: ${BACKUP_FILE}"
Troubleshooting
Common Issues
# Container exits immediately
docker logs container_name
docker run -it image_name sh   # start a fresh container from the image with a shell
# Port already in use
docker ps -a | grep 8080
lsof -i :8080
# Out of disk space
docker system df
docker system prune # Remove unused data
docker system prune -a # Remove all unused data
# Permission denied
sudo usermod -aG docker $USER
newgrp docker
# Network issues
docker network ls
docker network inspect bridge
# Image pull errors
docker pull --platform linux/amd64 image_name
Debugging Commands
# Inspect container
docker inspect --format='{{.State.Status}}' container_name
docker inspect --format='{{.NetworkSettings.IPAddress}}' container_name
# Container events
docker events --filter container=container_name
# System information
docker info
docker version
# Resource usage
docker stats
docker top container_name
Useful Aliases
# Add to ~/.bashrc or ~/.zshrc
alias dps='docker ps'
alias dpsa='docker ps -a'
alias di='docker images'
alias drm='docker rm'
alias drmi='docker rmi'
alias dstop='docker stop $(docker ps -q)'
alias dclean='docker system prune -af'
alias dlog='docker logs -f'
alias dexec='docker exec -it'
Quick Reference
| Command | Description |
|---|---|
| `docker run` | Create and start container |
| `docker ps` | List running containers |
| `docker stop` | Stop container |
| `docker rm` | Remove container |
| `docker images` | List images |
| `docker pull` | Download image |
| `docker build` | Build image from Dockerfile |
| `docker push` | Upload image to registry |
| `docker logs` | View container logs |
| `docker exec` | Run command in container |
| `docker-compose up` | Start services |
| `docker-compose down` | Stop services |
Docker simplifies application deployment and ensures consistency across development, testing, and production environments.
Ansible
Ansible is an open-source IT automation engine that automates provisioning, configuration management, application deployment, orchestration, and many other IT processes. It uses SSH for communication and requires no agents on managed nodes.
Overview
Ansible uses a simple, human-readable language (YAML) to describe automation jobs. It's agentless, using OpenSSH for transport, making it secure and easy to set up.
Key Concepts:
- Inventory: List of managed nodes (hosts)
- Playbook: YAML files defining tasks to execute
- Module: Reusable code units for specific tasks
- Role: Organized collection of playbooks and files
- Task: Single action to be performed
- Handler: Tasks triggered by notifications
- Facts: System information gathered from hosts
Installation
# Ubuntu/Debian
sudo apt update
sudo apt install ansible
# macOS
brew install ansible
# CentOS/RHEL
sudo yum install epel-release
sudo yum install ansible
# Using pip
pip install ansible
# Verify installation
ansible --version
Basic Configuration
Ansible Config
# Create ansible.cfg
cat << 'EOF' > ansible.cfg
[defaults]
inventory = ./inventory
host_key_checking = False
remote_user = ansible
private_key_file = ~/.ssh/id_rsa
retry_files_enabled = False
gathering = smart
fact_caching = jsonfile
fact_caching_connection = /tmp/ansible_facts
fact_caching_timeout = 3600
[privilege_escalation]
become = True
become_method = sudo
become_user = root
become_ask_pass = False
EOF
Inventory File
# inventory or hosts file
# Single host
web1.example.com
# Group of hosts
[webservers]
web1.example.com
web2.example.com
192.168.1.10
# Multiple groups
[databases]
db1.example.com
db2.example.com
[app:children]
webservers
databases
# Host with variables
[webservers]
web1.example.com ansible_user=admin ansible_port=2222
# Group variables
[webservers:vars]
ansible_user=deploy
ansible_python_interpreter=/usr/bin/python3
http_port=80
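The `[app:children]` syntax above makes `app` a parent group whose hosts are the union of its member groups (recursively). A hypothetical sketch of that expansion, not Ansible's internal inventory code:

```python
def resolve_hosts(group, groups, children):
    """Expand a group (possibly with :children members) to its full host set."""
    hosts = set(groups.get(group, []))
    for child in children.get(group, []):
        hosts |= resolve_hosts(child, groups, children)  # recurse into child groups
    return hosts

groups = {
    "webservers": ["web1.example.com", "web2.example.com", "192.168.1.10"],
    "databases": ["db1.example.com", "db2.example.com"],
}
children = {"app": ["webservers", "databases"]}
print(sorted(resolve_hosts("app", groups, children)))
```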
Dynamic Inventory (YAML)
# inventory.yml
all:
hosts:
web1.example.com:
web2.example.com:
children:
webservers:
hosts:
web1.example.com:
ansible_user: deploy
web2.example.com:
ansible_user: deploy
vars:
http_port: 80
databases:
hosts:
db1.example.com:
db2.example.com:
vars:
db_port: 5432
Ad-hoc Commands
# Ping all hosts
ansible all -m ping
# Ping specific group
ansible webservers -m ping
# Run shell command
ansible all -m shell -a "uptime"
ansible webservers -a "df -h" # command module is the default for -a
# Copy file
ansible all -m copy -a "src=/local/file dest=/remote/file"
# Install package
ansible webservers -m apt -a "name=nginx state=present" --become
# Start service
ansible webservers -m service -a "name=nginx state=started" --become
# Gather facts
ansible all -m setup
# Specific fact
ansible all -m setup -a "filter=ansible_distribution*"
# Execute with sudo
ansible all -a "systemctl restart nginx" --become
# Execute as specific user
ansible all -a "whoami" --become-user=www-data
Playbooks
Basic Playbook
# playbook.yml
---
- name: Configure web servers
hosts: webservers
become: yes
tasks:
- name: Install nginx
apt:
name: nginx
state: present
update_cache: yes
- name: Start nginx service
service:
name: nginx
state: started
enabled: yes
- name: Copy index.html
copy:
src: files/index.html
dest: /var/www/html/index.html
owner: www-data
group: www-data
mode: '0644'
Running Playbooks
# Run playbook
ansible-playbook playbook.yml
# Dry run (check mode)
ansible-playbook playbook.yml --check
# Show differences
ansible-playbook playbook.yml --check --diff
# Limit to specific hosts
ansible-playbook playbook.yml --limit web1.example.com
ansible-playbook playbook.yml --limit webservers
# Tags
ansible-playbook playbook.yml --tags "install"
ansible-playbook playbook.yml --skip-tags "config"
# Start at specific task
ansible-playbook playbook.yml --start-at-task="Install nginx"
# Verbose output
ansible-playbook playbook.yml -v # verbose
ansible-playbook playbook.yml -vv # more verbose
ansible-playbook playbook.yml -vvv # very verbose
Variables in Playbooks
---
- name: Configure application
hosts: webservers
vars:
app_name: myapp
app_version: "1.0"
app_port: 8080
tasks:
- name: Create app directory
file:
path: "/opt/{{ app_name }}"
state: directory
owner: "{{ ansible_user }}"
- name: Display variables
debug:
msg: "Deploying {{ app_name }} version {{ app_version }} on port {{ app_port }}"
Variables from Files
# vars.yml
---
app_name: myapp
app_version: "1.0"
app_port: 8080
database:
host: db.example.com
name: myapp_db
user: myapp_user
# playbook.yml
---
- name: Configure application
hosts: webservers
vars_files:
- vars.yml
tasks:
- name: Display app info
debug:
msg: "App: {{ app_name }}, DB: {{ database.host }}"
Common Modules
System Modules
# User management
- name: Create user
user:
name: deploy
state: present
groups: sudo
shell: /bin/bash
create_home: yes
# Group management
- name: Create group
group:
name: developers
state: present
# File operations
- name: Create file
file:
path: /tmp/test.txt
state: touch
mode: '0644'
owner: deploy
- name: Create directory
file:
path: /opt/myapp
state: directory
mode: '0755'
recurse: yes
# Copy files
- name: Copy file
copy:
src: files/config.conf
dest: /etc/myapp/config.conf
backup: yes
# Template files
- name: Deploy template
template:
src: templates/nginx.conf.j2
dest: /etc/nginx/nginx.conf
validate: 'nginx -t -c %s'
notify: restart nginx
Package Management
# APT (Debian/Ubuntu)
- name: Install packages
apt:
name:
- nginx
- postgresql
- python3-pip
state: present
update_cache: yes
# YUM/DNF (RedHat/CentOS)
- name: Install packages
yum:
name:
- httpd
- mariadb-server
state: present
# Package from URL
- name: Install deb package
apt:
deb: https://example.com/package.deb
# Remove package
- name: Remove package
apt:
name: apache2
state: absent
purge: yes
Service Management
- name: Manage service
service:
name: nginx
state: started
enabled: yes
- name: Restart service
service:
name: apache2
state: restarted
- name: Reload service
service:
name: nginx
state: reloaded
Command Execution
# Shell module
- name: Run shell command
shell: echo $HOME
register: home_dir
- name: Display output
debug:
var: home_dir.stdout
# Command module (no shell features)
- name: Run command
command: /usr/bin/uptime
register: uptime_result
# Script execution
- name: Run script
script: scripts/setup.sh
# Execute with conditions
- name: Check file exists
stat:
path: /etc/config.conf
register: config_file
- name: Run if file exists
command: /usr/bin/process_config
when: config_file.stat.exists
Handlers
---
- name: Configure nginx
hosts: webservers
become: yes
tasks:
- name: Copy nginx config
template:
src: nginx.conf.j2
dest: /etc/nginx/nginx.conf
notify:
- restart nginx
- reload nginx
- name: Copy site config
template:
src: site.conf.j2
dest: /etc/nginx/sites-available/default
notify: reload nginx
handlers:
- name: restart nginx
service:
name: nginx
state: restarted
- name: reload nginx
service:
name: nginx
state: reloaded
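Handlers run once at the end of the play no matter how many tasks notified them, and in the order they are defined, not the order of notification. A hypothetical sketch of that dedup rule:

```python
def handlers_to_run(notifications, defined_handlers):
    """Each notified handler runs exactly once, in definition order."""
    notified = set(notifications)
    return [h for h in defined_handlers if h in notified]

defined = ["restart nginx", "reload nginx"]
# Two tasks notified 'reload nginx', one notified 'restart nginx':
print(handlers_to_run(["reload nginx", "restart nginx", "reload nginx"], defined))
```

So in the playbook above, even if both templates change, nginx is reloaded only once.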
Roles
Creating a Role
# Create role structure
ansible-galaxy init myrole
# Directory structure
myrole/
├── defaults/ # Default variables
│ └── main.yml
├── files/ # Static files
├── handlers/ # Handlers
│ └── main.yml
├── meta/ # Role metadata
│ └── main.yml
├── tasks/ # Main tasks
│ └── main.yml
├── templates/ # Jinja2 templates
├── tests/ # Test playbooks
│ └── test.yml
└── vars/ # Role variables
└── main.yml
Role Example
# roles/nginx/tasks/main.yml
---
- name: Install nginx
apt:
name: nginx
state: present
update_cache: yes
- name: Copy nginx config
template:
src: nginx.conf.j2
dest: /etc/nginx/nginx.conf
notify: restart nginx
- name: Start nginx
service:
name: nginx
state: started
enabled: yes
# roles/nginx/handlers/main.yml
---
- name: restart nginx
service:
name: nginx
state: restarted
# roles/nginx/defaults/main.yml
---
nginx_port: 80
nginx_user: www-data
# Using the role
---
- name: Setup web server
hosts: webservers
become: yes
roles:
- nginx
- { role: mysql, mysql_port: 3306 }
Conditionals and Loops
Conditionals
---
- name: Conditional tasks
hosts: all
tasks:
- name: Install on Ubuntu
apt:
name: nginx
state: present
when: ansible_distribution == "Ubuntu"
- name: Install on CentOS
yum:
name: httpd
state: present
when: ansible_distribution == "CentOS"
- name: Multiple conditions (AND)
apt:
name: nginx
state: present
when:
- ansible_distribution == "Ubuntu"
- ansible_distribution_version == "20.04"
- name: Multiple conditions (OR)
apt:
name: nginx
state: present
when: ansible_distribution == "Ubuntu" or ansible_distribution == "Debian"
Loops
---
- name: Loop examples
hosts: all
tasks:
# Simple loop
- name: Install multiple packages
apt:
name: "{{ item }}"
state: present
loop:
- nginx
- postgresql
- redis-server
# Loop with dictionary
- name: Create users
user:
name: "{{ item.name }}"
groups: "{{ item.groups }}"
state: present
loop:
- { name: 'alice', groups: 'developers' }
- { name: 'bob', groups: 'admins' }
# Loop with complex data
- name: Create directories
file:
path: "{{ item.path }}"
state: directory
owner: "{{ item.owner }}"
mode: "{{ item.mode }}"
loop:
- { path: '/opt/app1', owner: 'deploy', mode: '0755' }
- { path: '/opt/app2', owner: 'www-data', mode: '0750' }
Templates
{# templates/nginx.conf.j2 #}
user {{ nginx_user }};
worker_processes {{ ansible_processor_vcpus }};
events {
worker_connections 1024;
}
http {
server {
listen {{ nginx_port }};
server_name {{ ansible_hostname }};
location / {
root /var/www/html;
index index.html;
}
}
}
{# Conditional content #}
{% if enable_ssl %}
ssl on;
ssl_certificate {{ ssl_cert_path }};
{% endif %}
{# Loop in template #}
{% for server in backend_servers %}
upstream backend_{{ loop.index }} {
server {{ server.host }}:{{ server.port }};
}
{% endfor %}
Vault (Encryption)
# Create encrypted file
ansible-vault create secrets.yml
# Edit encrypted file
ansible-vault edit secrets.yml
# Encrypt existing file
ansible-vault encrypt vars.yml
# Decrypt file
ansible-vault decrypt vars.yml
# View encrypted file
ansible-vault view secrets.yml
# Change password
ansible-vault rekey secrets.yml
# Use vault in playbook
ansible-playbook playbook.yml --ask-vault-pass
# Use password file
ansible-playbook playbook.yml --vault-password-file ~/.vault_pass
# Multiple vaults
ansible-playbook playbook.yml --vault-id prod@prompt --vault-id dev@~/.vault_pass_dev
Vault Example
# secrets.yml (encrypted)
db_password: "super_secret_password"
api_key: "abc123xyz789"
# playbook.yml
---
- name: Deploy with secrets
hosts: webservers
vars_files:
- secrets.yml
tasks:
- name: Configure database
template:
src: db_config.j2
dest: /etc/app/db_config.conf
Best Practices
Playbook Organization
# Recommended directory structure
site.yml # Master playbook
webservers.yml # Webserver playbook
dbservers.yml # Database playbook
inventory/
├── production/
│ ├── hosts
│ └── group_vars/
│ ├── all.yml
│ ├── webservers.yml
│ └── dbservers.yml
└── staging/
├── hosts
└── group_vars/
roles/
├── common/
├── nginx/
├── postgresql/
└── app/
group_vars/
├── all.yml
├── webservers.yml
└── dbservers.yml
host_vars/
└── web1.example.com.yml
Best Practices
# 1. Use names for all tasks
- name: Install nginx
apt:
name: nginx
state: present
# 2. Use become appropriately
- name: System task
become: yes
apt:
name: nginx
state: present
# 3. Validate configurations
- name: Deploy nginx config
template:
src: nginx.conf.j2
dest: /etc/nginx/nginx.conf
validate: 'nginx -t -c %s'
# 4. Use check mode compatible tasks
- name: Check if service exists
stat:
path: /etc/systemd/system/myapp.service
register: service_file
check_mode: no
# 5. Add tags
- name: Install packages
apt:
name: nginx
state: present
tags: ['install', 'packages']
# 6. Use blocks for error handling
- block:
- name: Risky operation
command: /usr/bin/risky_command
rescue:
- name: Handle error
debug:
msg: "Command failed, handling gracefully"
always:
- name: Cleanup
file:
path: /tmp/temp_file
state: absent
Common Patterns
Complete Web Server Setup
---
- name: Configure web servers
hosts: webservers
become: yes
vars:
app_name: myapp
app_user: www-data
tasks:
- name: Update apt cache
apt:
update_cache: yes
cache_valid_time: 3600
- name: Install packages
apt:
name:
- nginx
- python3-pip
- git
state: present
- name: Create app directory
file:
path: "/var/www/{{ app_name }}"
state: directory
owner: "{{ app_user }}"
mode: '0755'
- name: Deploy nginx config
template:
src: nginx.conf.j2
dest: /etc/nginx/sites-available/{{ app_name }}
notify: reload nginx
- name: Enable site
file:
src: /etc/nginx/sites-available/{{ app_name }}
dest: /etc/nginx/sites-enabled/{{ app_name }}
state: link
notify: reload nginx
- name: Start nginx
service:
name: nginx
state: started
enabled: yes
handlers:
- name: reload nginx
service:
name: nginx
state: reloaded
Multi-stage Deployment
---
- name: Deploy application
hosts: webservers
serial: 1 # Rolling update
max_fail_percentage: 25
pre_tasks:
- name: Remove from load balancer
haproxy:
state: disabled
host: "{{ ansible_hostname }}"
tasks:
- name: Deploy application
git:
repo: https://github.com/user/app.git
dest: /opt/app
version: "{{ app_version }}"
- name: Install dependencies
pip:
requirements: /opt/app/requirements.txt
- name: Restart app service
service:
name: myapp
state: restarted
- name: Wait for app to start
wait_for:
port: 8080
delay: 5
timeout: 30
post_tasks:
- name: Add to load balancer
haproxy:
state: enabled
host: "{{ ansible_hostname }}"
Troubleshooting
# Check syntax
ansible-playbook playbook.yml --syntax-check
# List tasks
ansible-playbook playbook.yml --list-tasks
# List hosts
ansible-playbook playbook.yml --list-hosts
# Dry run
ansible-playbook playbook.yml --check
# Debug mode
ansible-playbook playbook.yml -vvv
# Start at specific task
ansible-playbook playbook.yml --start-at-task="Install nginx"
# Step through playbook
ansible-playbook playbook.yml --step
# Gather facts only
ansible all -m setup --tree /tmp/facts
Quick Reference
| Command | Description |
|---|---|
| `ansible all -m ping` | Ping all hosts |
| `ansible-playbook playbook.yml` | Run playbook |
| `ansible-playbook --check` | Dry run |
| `ansible-playbook --tags TAG` | Run specific tags |
| `ansible-playbook --limit HOST` | Limit to hosts |
| `ansible-vault create FILE` | Create encrypted file |
| `ansible-galaxy init ROLE` | Create role |
| `ansible-inventory --list` | Show inventory |
Ansible simplifies IT automation with its agentless architecture and simple YAML syntax, making infrastructure management efficient and reproducible.
wpa_supplicant
A comprehensive guide to wpa_supplicant, the IEEE 802.11 authentication daemon for WiFi client connectivity on Linux.
Table of Contents
- Overview
- Installation
- Configuration Files
- Basic Usage
- Network Configuration
- Command-Line Interface
- wpa_cli Interactive Mode
- Advanced Configuration
- Security Modes
- Enterprise WiFi (802.1X)
- P2P WiFi Direct
- Troubleshooting
- Integration with systemd
- Best Practices
Overview
wpa_supplicant is a WPA/WPA2/WPA3 supplicant for Linux and other UNIX-like operating systems. It handles WiFi authentication and association for client stations.
Key Features
- WPA/WPA2/WPA3-Personal (PSK)
- WPA/WPA2/WPA3-Enterprise (802.1X/EAP)
- WEP (deprecated, for legacy networks)
- Hotspot 2.0 (Passpoint)
- WiFi Protected Setup (WPS)
- WiFi Direct (P2P)
- Automatic network selection
- Dynamic reconfiguration via control interface
Architecture
┌─────────────────────────────────────┐
│ User Space │
│ │
│ ┌──────────┐ ┌──────────┐ │
│ │ wpa_cli │ │ NetworkMgr│ │
│ └────┬─────┘ └────┬──────┘ │
│ │ │ │
│ └────┬────────────┘ │
│ │ Control socket │
│ ┌────▼──────────────┐ │
│ │ wpa_supplicant │ │
│ └────┬──────────────┘ │
│ │ nl80211/WEXT │
└────────────┼─────────────────────────┘
│
┌────────────▼─────────────────────────┐
│ Kernel Space │
│ ┌──────────────────────────┐ │
│ │ cfg80211 / mac80211 │ │
│ └──────────┬───────────────┘ │
│ │ │
│ ┌──────────▼───────────────┐ │
│ │ WiFi Driver │ │
│ └──────────┬───────────────┘ │
└─────────────┼──────────────────────────┘
│
┌──────▼──────┐
│ WiFi Hardware│
└─────────────┘
Installation
Debian/Ubuntu
sudo apt-get update
sudo apt-get install wpasupplicant
# Verify installation
wpa_supplicant -v
Fedora/RHEL/CentOS
sudo dnf install wpa_supplicant
# Or for older systems
sudo yum install wpa_supplicant
Arch Linux
sudo pacman -S wpa_supplicant
Build from Source
# Download
git clone https://w1.fi/hostap.git
cd hostap/wpa_supplicant
# Configure
cp defconfig .config
# Edit .config to enable features
# Build
make
# Install
sudo make install
Configuration Files
Main Configuration File
Location: /etc/wpa_supplicant/wpa_supplicant.conf
Basic structure:
# Global settings
ctrl_interface=/var/run/wpa_supplicant
ctrl_interface_group=netdev
update_config=1
country=US
# Network configurations
network={
ssid="MyNetwork"
psk="password123"
}
File Permissions
# Secure the configuration file
sudo chmod 600 /etc/wpa_supplicant/wpa_supplicant.conf
sudo chown root:root /etc/wpa_supplicant/wpa_supplicant.conf
Global Parameters
# Control interface for wpa_cli
ctrl_interface=/var/run/wpa_supplicant
# Group that can access control interface
ctrl_interface_group=netdev
# Allow wpa_supplicant to update configuration
update_config=1
# Country code (affects regulatory domain)
country=US
# AP scanning mode
# 0 = driver takes care of scanning
# 1 = wpa_supplicant controls scanning (default)
# 2 = like 1, but use security policy
ap_scan=1
# Fast reauth for 802.1X
fast_reauth=1
# Enable P2P support
p2p_disabled=0
Basic Usage
Starting wpa_supplicant
# Basic usage
sudo wpa_supplicant -B -i wlan0 -c /etc/wpa_supplicant/wpa_supplicant.conf
# Options:
# -B: Run in background (daemon mode)
# -i: Network interface
# -c: Configuration file
# -D: Driver (nl80211, wext, etc.) - usually auto-detected
# -d: Enable debug output
# -dd: More verbose debug
Starting with Debug Output
# Foreground with debug
sudo wpa_supplicant -i wlan0 -c /etc/wpa_supplicant/wpa_supplicant.conf -d
# Even more verbose
sudo wpa_supplicant -i wlan0 -c /etc/wpa_supplicant/wpa_supplicant.conf -dd
Stopping wpa_supplicant
# Find process
ps aux | grep wpa_supplicant
# Kill process
sudo killall wpa_supplicant
# Or using systemd
sudo systemctl stop wpa_supplicant@wlan0
Manual Connection Workflow
# 1. Bring interface up
sudo ip link set wlan0 up
# 2. Start wpa_supplicant
sudo wpa_supplicant -B -i wlan0 -c /etc/wpa_supplicant/wpa_supplicant.conf
# 3. Wait for connection (check with wpa_cli)
wpa_cli -i wlan0 status
# 4. Get IP address
sudo dhclient wlan0
# Or
sudo dhcpcd wlan0
Network Configuration
WPA/WPA2-Personal (PSK)
ASCII passphrase:
network={
ssid="MyWiFi"
psk="MyPassword123"
key_mgmt=WPA-PSK
priority=1
}
Pre-computed PSK (more secure):
# Generate PSK hash
wpa_passphrase "MyWiFi" "MyPassword123"
# Output:
network={
ssid="MyWiFi"
#psk="MyPassword123"
psk=0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef
}
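The 64-hex-digit `psk=` value above is PBKDF2-HMAC-SHA1 of the passphrase, with the SSID as salt, 4096 iterations, 32-byte output (per IEEE 802.11i). The same derivation in Python:

```python
import hashlib

def wpa_psk(ssid: str, passphrase: str) -> str:
    """Derive the WPA-PSK the way wpa_passphrase does:
    PBKDF2-HMAC-SHA1(passphrase, salt=ssid, 4096 iterations, 32 bytes)."""
    return hashlib.pbkdf2_hmac(
        "sha1", passphrase.encode(), ssid.encode(), 4096, 32
    ).hex()

# IEEE 802.11i test vector: passphrase "password", SSID "IEEE"
print(wpa_psk("IEEE", "password"))
```

Storing the derived hex instead of the plaintext passphrase means the config file never reveals the original password, though the hash itself still grants access to that one network.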
In configuration file:
network={
ssid="MyWiFi"
psk=0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef
key_mgmt=WPA-PSK
}
WPA3-Personal (SAE)
network={
ssid="MyWiFi-WPA3"
psk="MyPassword123"
key_mgmt=SAE
ieee80211w=2 # Required for WPA3 (PMF)
}
Open Network (No Security)
network={
ssid="OpenWiFi"
key_mgmt=NONE
}
Hidden Network
network={
ssid="HiddenSSID"
scan_ssid=1 # Enable active scanning
psk="password"
key_mgmt=WPA-PSK
}
WEP (Deprecated)
network={
ssid="OldNetwork"
key_mgmt=NONE
wep_key0="1234567890"
wep_tx_keyidx=0
}
Multiple Networks with Priority
# Home network - highest priority
network={
ssid="HomeWiFi"
psk="homepassword"
priority=10
}
# Work network
network={
ssid="WorkWiFi"
psk="workpassword"
priority=5
}
# Coffee shop - lowest priority
network={
ssid="CoffeeShop"
key_mgmt=NONE
priority=1
}
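With several configured networks in range, wpa_supplicant joins the one with the highest `priority` (larger number wins; ties fall back to other criteria such as signal strength). A hypothetical sketch of that selection:

```python
def select_network(configured, visible_ssids):
    """Pick the highest-priority configured network that is currently visible.
    `configured` maps ssid -> priority (higher wins)."""
    candidates = [(prio, ssid) for ssid, prio in configured.items()
                  if ssid in visible_ssids]
    if not candidates:
        return None
    return max(candidates)[1]  # max by priority

configured = {"HomeWiFi": 10, "WorkWiFi": 5, "CoffeeShop": 1}
# HomeWiFi is out of range, so the next-highest visible network wins:
print(select_network(configured, {"CoffeeShop", "WorkWiFi"}))  # WorkWiFi
```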
BSSID-Specific Configuration
# Connect only to specific AP
network={
ssid="MyWiFi"
bssid=00:11:22:33:44:55
psk="password"
}
Command-Line Interface
wpa_cli - Control Interface
Basic commands:
# Show status
wpa_cli -i wlan0 status
# Scan for networks
wpa_cli -i wlan0 scan
wpa_cli -i wlan0 scan_results
# List configured networks
wpa_cli -i wlan0 list_networks
# Add network
wpa_cli -i wlan0 add_network
# Returns: 0 (network ID)
# Set network parameters
wpa_cli -i wlan0 set_network 0 ssid '"MyWiFi"'
wpa_cli -i wlan0 set_network 0 psk '"password"'
# Enable network
wpa_cli -i wlan0 enable_network 0
# Select network
wpa_cli -i wlan0 select_network 0
# Save configuration
wpa_cli -i wlan0 save_config
# Remove network
wpa_cli -i wlan0 remove_network 0
# Disconnect
wpa_cli -i wlan0 disconnect
# Reconnect
wpa_cli -i wlan0 reconnect
# Reassociate
wpa_cli -i wlan0 reassociate
Quick Connection
# One-liner to connect
wpa_cli -i wlan0 <<EOF
add_network
set_network 0 ssid "MyWiFi"
set_network 0 psk "password"
enable_network 0
save_config
quit
EOF
wpa_cli Interactive Mode
Starting Interactive Mode
wpa_cli -i wlan0
Interactive session:
wpa_cli v2.9
Copyright (c) 2004-2019, Jouni Malinen <j@w1.fi> and contributors
Interactive mode
> status
bssid=00:11:22:33:44:55
freq=2437
ssid=MyWiFi
id=0
mode=station
pairwise_cipher=CCMP
group_cipher=CCMP
key_mgmt=WPA2-PSK
wpa_state=COMPLETED
ip_address=192.168.1.100
address=aa:bb:cc:dd:ee:ff
> scan
OK
> scan_results
bssid / frequency / signal level / flags / ssid
00:11:22:33:44:55 2437 -45 [WPA2-PSK-CCMP][ESS] MyWiFi
aa:bb:cc:dd:ee:ff 2462 -67 [WPA2-PSK-CCMP][ESS] NeighborWiFi
> quit
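Real `scan_results` output is tab-separated (bssid, frequency, signal level in dBm — closer to 0 is stronger — flags, ssid). A small parser that picks the strongest AP, assuming tab-separated fields:

```python
def strongest_ap(scan_results: str):
    """Return (ssid, signal_dbm) for the strongest AP in scan_results text.
    Assumes tab-separated fields; the first (header) line is skipped."""
    best = None
    for line in scan_results.splitlines()[1:]:
        fields = line.split("\t")
        if len(fields) < 5:
            continue  # ignore malformed lines
        _bssid, _freq, signal, _flags, ssid = fields[:5]
        if best is None or int(signal) > best[1]:
            best = (ssid, int(signal))
    return best

sample = ("bssid / frequency / signal level / flags / ssid\n"
          "00:11:22:33:44:55\t2437\t-45\t[WPA2-PSK-CCMP][ESS]\tMyWiFi\n"
          "aa:bb:cc:dd:ee:ff\t2462\t-67\t[WPA2-PSK-CCMP][ESS]\tNeighborWiFi\n")
print(strongest_ap(sample))  # ('MyWiFi', -45)
```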
Common Interactive Commands
status - Show connection status
scan - Trigger network scan
scan_results - Show scan results
list_networks - List configured networks
select_network <id> - Select network
enable_network <id> - Enable network
disable_network <id> - Disable network
remove_network <id> - Remove network
add_network - Add new network
set_network <id> <var> <value> - Set network parameter
save_config - Save configuration
disconnect - Disconnect from AP
reconnect - Reconnect to AP
reassociate - Force reassociation
terminate - Terminate wpa_supplicant
quit - Exit wpa_cli
Advanced Configuration
Band Selection (2.4 GHz vs 5 GHz)
network={
ssid="DualBandWiFi"
psk="password"
# Prefer 5 GHz
freq_list=5180 5200 5220 5240 5260 5280 5300 5320
}
Power Saving
# Power saving is a driver-level setting rather than a wpa_supplicant.conf
# option; control it with iw:
iw dev wlan0 get power_save
sudo iw dev wlan0 set power_save on
sudo iw dev wlan0 set power_save off
Roaming
network={
ssid="EnterpriseWiFi"
psk="password"
# Fast roaming (802.11r)
key_mgmt=FT-PSK
# Proactive key caching
proactive_key_caching=1
# BSS transition management
bss_transition=1
}
MAC Address Randomization
# Per-network MAC randomization
network={
ssid="PublicWiFi"
key_mgmt=NONE
mac_addr=1 # Random MAC per network
}
# Global setting
mac_addr=1
# 0 = Use permanent MAC address
# 1 = Random MAC for each network (ESS) connection
# 2 = Like 1, but keep the vendor OUI
IPv6
# wpa_supplicant handles layer-2 authentication only; IP configuration
# (including IPv6) belongs to your DHCP client or network manager.
# To disable IPv6 on the wireless interface, use sysctl instead:
sudo sysctl -w net.ipv6.conf.wlan0.disable_ipv6=1
Security Modes
WPA2-Enterprise (EAP-PEAP/MSCHAPv2)
network={
ssid="CorpWiFi"
key_mgmt=WPA-EAP
eap=PEAP
identity="username@domain.com"
password="userpassword"
phase2="auth=MSCHAPV2"
# Certificate verification
ca_cert="/etc/ssl/certs/ca-bundle.crt"
# Omitting ca_cert skips server certificate verification (insecure!)
}
WPA2-Enterprise (EAP-TLS with Certificates)
network={
ssid="SecureCorpWiFi"
key_mgmt=WPA-EAP
eap=TLS
identity="user@company.com"
# Client certificate
client_cert="/etc/wpa_supplicant/client.crt"
# Private key
private_key="/etc/wpa_supplicant/client.key"
# Private key password
private_key_passwd="keypassword"
# CA certificate
ca_cert="/etc/wpa_supplicant/ca.crt"
}
WPA2-Enterprise (EAP-TTLS/PAP)
network={
ssid="UniversityWiFi"
key_mgmt=WPA-EAP
eap=TTLS
identity="student@university.edu"
password="studentpass"
phase2="auth=PAP"
ca_cert="/etc/ssl/certs/ca-bundle.crt"
}
Eduroam Configuration
network={
ssid="eduroam"
key_mgmt=WPA-EAP
eap=PEAP
identity="username@institution.edu"
password="password"
phase2="auth=MSCHAPV2"
ca_cert="/etc/ssl/certs/ca-certificates.crt"
}
Enterprise WiFi (802.1X)
Certificate Management
# Download CA certificate
wget https://your-ca.com/ca.crt -O /etc/wpa_supplicant/ca.crt
# Set permissions
sudo chmod 600 /etc/wpa_supplicant/ca.crt
# Convert certificate format if needed
openssl x509 -inform DER -in ca.der -out ca.pem
Anonymous Identity (Privacy)
network={
ssid="CorpWiFi"
key_mgmt=WPA-EAP
eap=PEAP
# Anonymous outer identity
anonymous_identity="anonymous@company.com"
# Real identity (inner)
identity="realuser@company.com"
password="password"
phase2="auth=MSCHAPV2"
ca_cert="/etc/wpa_supplicant/ca.crt"
}
Domain Suffix Matching
network={
ssid="SecureWiFi"
key_mgmt=WPA-EAP
eap=PEAP
identity="user@company.com"
password="password"
phase2="auth=MSCHAPV2"
# Verify server domain
domain_suffix_match="radius.company.com"
ca_cert="/etc/wpa_supplicant/ca.crt"
}
P2P WiFi Direct
Enable WiFi Direct
# Global setting
ctrl_interface=/var/run/wpa_supplicant
p2p_disabled=0
device_name=MyDevice
device_type=1-0050F204-1
P2P Commands
# Start P2P mode
wpa_cli -i wlan0 p2p_find
# Stop search
wpa_cli -i wlan0 p2p_stop_find
# Connect to peer
wpa_cli -i wlan0 p2p_connect <peer_mac> pbc
# Group formation
wpa_cli -i wlan0 p2p_group_add
# Show peers
wpa_cli -i wlan0 p2p_peers
Troubleshooting
Check Status
# Interface status
ip link show wlan0
# wpa_supplicant status
wpa_cli -i wlan0 status
# Connection state
wpa_cli -i wlan0 status | grep wpa_state
# COMPLETED = connected
# SCANNING = scanning for networks
# ASSOCIATING = connecting
# DISCONNECTED = not connected
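The key=value output of `wpa_cli status` is easy to consume from scripts. A minimal Python sketch (the sample text below is illustrative, not captured from a real interface):

```python
def parse_status(output: str) -> dict:
    """Parse the key=value lines printed by `wpa_cli -i wlan0 status`."""
    fields = {}
    for line in output.splitlines():
        if "=" in line:
            key, _, value = line.partition("=")
            fields[key.strip()] = value.strip()
    return fields

# Illustrative sample of typical status output
sample = """\
wpa_state=COMPLETED
ssid=HomeWiFi
ip_address=192.168.1.42
"""

info = parse_status(sample)
connected = info.get("wpa_state") == "COMPLETED"
```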
Debug Logging
# Run in foreground with debug
sudo killall wpa_supplicant
sudo wpa_supplicant -i wlan0 -c /etc/wpa_supplicant/wpa_supplicant.conf -dd
# Check system logs
sudo journalctl -u wpa_supplicant@wlan0 -f
# dmesg for driver issues
dmesg | grep -i wifi
dmesg | grep -i wlan
Common Issues
Authentication failure:
# Check password
wpa_passphrase "SSID" "password"
# Verify security mode
wpa_cli -i wlan0 scan_results
# Look for [WPA2-PSK-CCMP], [WPA3-SAE], etc.
# Check logs
sudo journalctl -u wpa_supplicant@wlan0 | grep -i "auth\|fail"
Cannot scan networks:
# Check if interface is up
sudo ip link set wlan0 up
# Check rfkill
rfkill list
sudo rfkill unblock wifi
# Manual scan
sudo iw dev wlan0 scan | grep SSID
Frequent disconnections:
# Check signal strength
watch -n 1 'iw dev wlan0 link'
# Disable power management
sudo iw dev wlan0 set power_save off
# Legacy equivalent: sudo iwconfig wlan0 power off
# Check logs for errors
sudo journalctl -u wpa_supplicant@wlan0 --since "10 minutes ago"
Driver issues:
# Check driver
lspci -k | grep -A 3 -i network
# Or for USB
lsusb
dmesg | grep -i firmware
# Reload driver
sudo modprobe -r <driver_name>
sudo modprobe <driver_name>
Integration with systemd
systemd Service
Per-interface service:
# Start service
sudo systemctl start wpa_supplicant@wlan0
# Enable on boot
sudo systemctl enable wpa_supplicant@wlan0
# Status
sudo systemctl status wpa_supplicant@wlan0
# Restart
sudo systemctl restart wpa_supplicant@wlan0
Service file: /lib/systemd/system/wpa_supplicant@.service
[Unit]
Description=WPA supplicant daemon (interface-specific version)
Requires=sys-subsystem-net-devices-%i.device
After=sys-subsystem-net-devices-%i.device
Before=network.target
Wants=network.target
[Service]
Type=simple
ExecStart=/sbin/wpa_supplicant -c/etc/wpa_supplicant/wpa_supplicant-%I.conf -i%I
[Install]
WantedBy=multi-user.target
networkd Integration
/etc/systemd/network/25-wireless.network:
[Match]
Name=wlan0
[Network]
DHCP=yes
Start services:
sudo systemctl enable systemd-networkd
sudo systemctl enable wpa_supplicant@wlan0
sudo systemctl start systemd-networkd
sudo systemctl start wpa_supplicant@wlan0
Best Practices
Security
- Use encrypted PSK:
# Generate PSK hash instead of plaintext
wpa_passphrase "SSID" "password" | sudo tee -a /etc/wpa_supplicant/wpa_supplicant.conf
- Secure configuration file:
sudo chmod 600 /etc/wpa_supplicant/wpa_supplicant.conf
- Use WPA3 when available:
network={
ssid="MyWiFi"
psk="password"
key_mgmt=SAE WPA-PSK # Try WPA3, fall back to WPA2
ieee80211w=1 # Optional PMF
}
- Verify certificates for Enterprise:
network={
ssid="CorpWiFi"
key_mgmt=WPA-EAP
ca_cert="/path/to/ca.crt"
domain_suffix_match="radius.company.com"
}
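The PSK hash that wpa_passphrase emits is defined by IEEE 802.11i: PBKDF2-HMAC-SHA1 over the passphrase with the SSID as salt, 4096 iterations, 32 bytes of output. A Python equivalent (SSID and passphrase here are placeholders):

```python
import hashlib

def wpa_psk(ssid: str, passphrase: str) -> str:
    """Derive the WPA-PSK hex string the way wpa_passphrase does:
    PBKDF2-HMAC-SHA1(passphrase, salt=ssid, 4096 iterations, 32 bytes)."""
    raw = hashlib.pbkdf2_hmac(
        "sha1", passphrase.encode(), ssid.encode(), 4096, dklen=32
    )
    return raw.hex()

# The 64-char hex digest is what goes into psk= (without quotes)
psk = wpa_psk("MyWiFi", "MySecurePassword")
```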
Performance
- Disable unnecessary features:
# Disable P2P if not needed
p2p_disabled=1
# Disable WPS
wps_disabled=1
- Optimize power saving:
# For performance (disable power save)
power_save=0
# For battery (enable power save)
power_save=2
- Fast roaming:
network={
ssid="EnterpriseWiFi"
key_mgmt=FT-PSK
proactive_key_caching=1
}
Reliability
- Network priority:
# Higher priority = preferred
network={
ssid="PrimaryWiFi"
priority=10
}
network={
ssid="BackupWiFi"
priority=5
}
- Automatic reconnection:
# systemd handles this automatically
sudo systemctl enable wpa_supplicant@wlan0
- Monitoring:
# Watch connection status
watch -n 2 'wpa_cli -i wlan0 status | grep -E "wpa_state|ssid|ip_address"'
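The priority rule above can be sketched as a selection over configured networks: among the visible ones, the highest priority value wins. A simplified Python model (real selection also weighs signal strength and security; network names are illustrative):

```python
def pick_network(visible_ssids, configured):
    """Return the visible configured network with the highest priority
    (simplified model of wpa_supplicant's priority-based selection)."""
    candidates = [n for n in configured if n["ssid"] in visible_ssids]
    if not candidates:
        return None
    return max(candidates, key=lambda n: n.get("priority", 0))

configured = [
    {"ssid": "PrimaryWiFi", "priority": 10},
    {"ssid": "BackupWiFi", "priority": 5},
]
best = pick_network({"PrimaryWiFi", "BackupWiFi"}, configured)
```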
Summary
wpa_supplicant is the standard WiFi client for Linux:
Basic workflow:
- Configure networks in /etc/wpa_supplicant/wpa_supplicant.conf
- Start: sudo wpa_supplicant -B -i wlan0 -c /etc/wpa_supplicant/wpa_supplicant.conf
- Manage: wpa_cli -i wlan0 <command>
- Get IP: sudo dhclient wlan0
Key commands:
- wpa_passphrase: Generate PSK hash
- wpa_supplicant: Main daemon
- wpa_cli: Control interface
- systemctl: Manage service
Common tasks:
- Connect to WPA2: Set ssid and psk
- Enterprise WiFi: Configure EAP method
- Scan networks: wpa_cli scan && wpa_cli scan_results
- Debug: Run with the -dd flag
hostapd
A comprehensive guide to hostapd, the IEEE 802.11 access point and authentication server for creating WiFi access points on Linux.
Table of Contents
- Overview
- Installation
- Basic Configuration
- Running hostapd
- Security Configurations
- Advanced Features
- Bridge Mode
- VLAN Support
- RADIUS Authentication
- 802.11n/ac/ax Configuration
- Monitoring and Management
- Troubleshooting
- Integration with systemd
- Best Practices
Overview
hostapd (host access point daemon) is a user-space daemon for access point and authentication servers. It implements IEEE 802.11 access point management, IEEE 802.1X/WPA/WPA2/WPA3/EAP authenticators, and RADIUS authentication server.
Key Features
- WiFi Access Point (AP) mode
- WPA/WPA2/WPA3-Personal and Enterprise
- Multiple SSIDs per radio (limit depends on the driver/hardware, often 8)
- VLAN tagging
- 802.11n/ac/ax (WiFi 4/5/6)
- RADIUS authentication
- WPS (WiFi Protected Setup)
- Hotspot 2.0
- Dynamic VLAN assignment
Use Cases
- Create WiFi hotspot on Linux
- Home router/AP
- Enterprise wireless access point
- Captive portal
- Guest WiFi network
- Testing and development
Installation
Debian/Ubuntu
sudo apt-get update
sudo apt-get install hostapd
# Verify installation
hostapd -v
Fedora/RHEL/CentOS
sudo dnf install hostapd
# Or for older systems
sudo yum install hostapd
Arch Linux
sudo pacman -S hostapd
Build from Source
# Download
git clone git://w1.fi/srv/git/hostap.git
cd hostap/hostapd
# Configure
cp defconfig .config
# Edit .config to enable features
# Build
make
# Install
sudo make install
Basic Configuration
Minimal Configuration
File: /etc/hostapd/hostapd.conf
# Interface to use
interface=wlan0
# Driver (nl80211 is modern standard)
driver=nl80211
# WiFi network name
ssid=MyAccessPoint
# WiFi mode (a = 5GHz, g = 2.4GHz)
hw_mode=g
# WiFi channel
channel=6
# WPA2 settings
wpa=2
wpa_passphrase=MySecurePassword123
wpa_key_mgmt=WPA-PSK
wpa_pairwise=CCMP
Open Network (No Security)
interface=wlan0
driver=nl80211
ssid=OpenWiFi
hw_mode=g
channel=6
# No WPA settings = open network
Basic WPA2 Access Point
# Interface configuration
interface=wlan0
driver=nl80211
# SSID configuration
ssid=MyWiFi
utf8_ssid=1
# Hardware mode
hw_mode=g
channel=6
# IEEE 802.11n
ieee80211n=1
wmm_enabled=1
# Security: WPA2-Personal
auth_algs=1
wpa=2
wpa_key_mgmt=WPA-PSK
rsn_pairwise=CCMP
wpa_passphrase=SecurePassword123
# Logging
logger_syslog=-1
logger_syslog_level=2
logger_stdout=-1
logger_stdout_level=2
# Country code
country_code=US
# Max clients
max_num_sta=20
Running hostapd
Manual Start
# Check configuration (hostapd has no dedicated syntax-check flag;
# start it in the foreground and parse errors are reported at startup --
# -t only adds timestamps to debug messages)
sudo hostapd -dd /etc/hostapd/hostapd.conf
# Run in foreground (for testing)
sudo hostapd /etc/hostapd/hostapd.conf
# Run in background
sudo hostapd -B /etc/hostapd/hostapd.conf
# With debug output
sudo hostapd -d /etc/hostapd/hostapd.conf
sudo hostapd -dd /etc/hostapd/hostapd.conf # More verbose
Complete Setup Script
#!/bin/bash
# setup-ap.sh
INTERFACE=wlan0
SSID="MyAccessPoint"
PASSWORD="MyPassword123"
CHANNEL=6
# Stop existing processes
sudo killall hostapd 2>/dev/null
sudo killall dnsmasq 2>/dev/null
# Configure interface
sudo ip link set $INTERFACE down
sudo ip addr flush dev $INTERFACE
sudo ip link set $INTERFACE up
sudo ip addr add 192.168.50.1/24 dev $INTERFACE
# Create hostapd config
cat > /tmp/hostapd.conf << EOF
interface=$INTERFACE
driver=nl80211
ssid=$SSID
hw_mode=g
channel=$CHANNEL
wmm_enabled=1
auth_algs=1
wpa=2
wpa_key_mgmt=WPA-PSK
wpa_pairwise=CCMP
wpa_passphrase=$PASSWORD
EOF
# Start hostapd
sudo hostapd -B /tmp/hostapd.conf
# Configure DHCP (dnsmasq)
sudo dnsmasq -C /dev/null \
--interface=$INTERFACE \
--dhcp-range=192.168.50.10,192.168.50.100,12h \
--no-daemon &
# Enable NAT
sudo sysctl net.ipv4.ip_forward=1
sudo iptables -t nat -A POSTROUTING -o eth0 -j MASQUERADE
sudo iptables -A FORWARD -i $INTERFACE -o eth0 -j ACCEPT
sudo iptables -A FORWARD -i eth0 -o $INTERFACE -m state --state RELATED,ESTABLISHED -j ACCEPT
echo "Access Point started: SSID=$SSID"
Security Configurations
WPA2-Personal (PSK)
interface=wlan0
ssid=SecureWiFi
# WPA2 with AES-CCMP
wpa=2
wpa_key_mgmt=WPA-PSK
rsn_pairwise=CCMP
wpa_passphrase=VerySecurePassword123
# Optional: require PMF (Protected Management Frames)
ieee80211w=1
WPA3-Personal (SAE)
interface=wlan0
ssid=WPA3WiFi
# WPA3-Personal (SAE)
wpa=2
wpa_key_mgmt=SAE
rsn_pairwise=CCMP
sae_password=SecureWPA3Password
# PMF is required for WPA3
ieee80211w=2
# SAE-specific settings
sae_pwe=2
sae_groups=19 20 21
WPA2/WPA3 Transition Mode
interface=wlan0
ssid=TransitionWiFi
# Both WPA2 and WPA3
wpa=2
wpa_key_mgmt=WPA-PSK SAE
rsn_pairwise=CCMP
# For WPA2
wpa_passphrase=Password123
# For WPA3
sae_password=Password123
# PMF optional (required for WPA3 clients)
ieee80211w=1
WPA2-Enterprise (802.1X)
interface=wlan0
ssid=EnterpriseWiFi
# WPA2-Enterprise
wpa=2
wpa_key_mgmt=WPA-EAP
rsn_pairwise=CCMP
# IEEE 802.1X
ieee8021x=1
# RADIUS server configuration
auth_server_addr=192.168.1.10
auth_server_port=1812
auth_server_shared_secret=radiussecret
# Optional: Accounting server
acct_server_addr=192.168.1.10
acct_server_port=1813
acct_server_shared_secret=radiussecret
# EAP configuration
eap_server=0
eapol_key_index_workaround=0
Hidden SSID
interface=wlan0
ssid=HiddenNetwork
# Hide SSID in beacons
ignore_broadcast_ssid=1
wpa=2
wpa_passphrase=password
MAC Address Filtering
interface=wlan0
ssid=FilteredWiFi
# MAC address ACL
macaddr_acl=1
# 0 = accept unless in deny list
# 1 = deny unless in accept list
# 2 = use external RADIUS
# Accept list
accept_mac_file=/etc/hostapd/accept.mac
# Deny list (if macaddr_acl=0)
deny_mac_file=/etc/hostapd/deny.mac
wpa=2
wpa_passphrase=password
/etc/hostapd/accept.mac:
00:11:22:33:44:55
aa:bb:cc:dd:ee:ff
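The accept/deny decision behind these files can be modeled simply: normalize the MAC, then apply the macaddr_acl mode. A Python sketch of the static modes (deny-list handling omitted for brevity; the MACs are the example entries above):

```python
def normalize(mac: str) -> str:
    """Normalize a MAC address to lowercase colon-separated form."""
    return mac.strip().lower().replace("-", ":")

def allowed(mac: str, accept_list, macaddr_acl: int) -> bool:
    """Simplified model of hostapd's static MAC ACL:
    macaddr_acl=1 -> deny unless in accept list;
    macaddr_acl=0 -> accept (unless in the deny list, not modeled here)."""
    in_accept = normalize(mac) in {normalize(m) for m in accept_list}
    if macaddr_acl == 1:
        return in_accept
    return True  # mode 0 with an empty deny list accepts everyone

accept = ["00:11:22:33:44:55", "AA:BB:CC:DD:EE:FF"]
```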
Advanced Features
Multiple SSIDs (Multi-BSS)
Main configuration /etc/hostapd/hostapd.conf:
# Primary interface
interface=wlan0
driver=nl80211
ctrl_interface=/var/run/hostapd
# Channel configuration (shared by all BSS)
hw_mode=g
channel=6
ieee80211n=1
# Primary SSID
ssid=MainWiFi
wpa=2
wpa_passphrase=MainPassword
# Multiple BSSs
bss=wlan0_0
ssid=GuestWiFi
wpa=2
wpa_passphrase=GuestPassword
# Isolate guest clients
ap_isolate=1
bss=wlan0_1
ssid=IoTWiFi
wpa=2
wpa_passphrase=IoTPassword
Client Isolation
interface=wlan0
ssid=IsolatedWiFi
# Prevent clients from communicating with each other
ap_isolate=1
wpa=2
wpa_passphrase=password
5 GHz Configuration
interface=wlan0
driver=nl80211
# 5 GHz band
hw_mode=a
channel=36
# 802.11n/ac must be enabled for the VHT settings below to take effect
ieee80211n=1
ieee80211ac=1
# Channel width (vht_oper_chwidth: 0 = 20/40 MHz, 1 = 80 MHz, 2 = 160 MHz)
vht_oper_chwidth=1
vht_oper_centr_freq_seg0_idx=42
ssid=5GHz_WiFi
wpa=2
wpa_passphrase=password
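Channel numbers map directly to center frequencies: 2407 + 5*ch MHz on 2.4 GHz (channels 1-13, with channel 14 a special case) and 5000 + 5*ch MHz on 5 GHz. A small helper:

```python
def channel_to_freq(channel: int) -> int:
    """Convert a WiFi channel number to its center frequency in MHz."""
    if 1 <= channel <= 13:        # 2.4 GHz band
        return 2407 + 5 * channel
    if channel == 14:             # Japan-only 2.4 GHz channel
        return 2484
    if 32 <= channel <= 196:      # 5 GHz band
        return 5000 + 5 * channel
    raise ValueError(f"unknown channel {channel}")

freq = channel_to_freq(36)  # channel 36 sits at 5180 MHz
```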
WPS (WiFi Protected Setup)
interface=wlan0
ssid=WPS_WiFi
wpa=2
wpa_passphrase=password
# Enable WPS
wps_state=2
eap_server=1
# Device information
device_name=Linux_AP
manufacturer=OpenSource
model_name=hostapd
model_number=1.0
config_methods=push_button keypad
# UUID (generate with uuidgen)
uuid=12345678-9abc-def0-1234-56789abcdef0
Trigger WPS:
# Push button
hostapd_cli wps_pbc
# PIN method
hostapd_cli wps_pin any 12345670
Bridge Mode
Bridge Configuration
# Create bridge
sudo ip link add name br0 type bridge
sudo ip link set br0 up
# Add Ethernet to bridge
sudo ip link set eth0 master br0
# Configure bridge IP
sudo ip addr add 192.168.1.1/24 dev br0
hostapd.conf:
interface=wlan0
bridge=br0
driver=nl80211
ssid=BridgedWiFi
hw_mode=g
channel=6
wpa=2
wpa_passphrase=password
Complete Bridge Setup
#!/bin/bash
# bridge-ap.sh
WLAN=wlan0
ETH=eth0
BRIDGE=br0
# Create bridge
sudo ip link add name $BRIDGE type bridge
sudo ip link set $BRIDGE up
# Add Ethernet
sudo ip link set $ETH down
sudo ip addr flush dev $ETH
sudo ip link set $ETH master $BRIDGE
sudo ip link set $ETH up
# Configure bridge
sudo ip addr add 192.168.1.1/24 dev $BRIDGE
# hostapd config with bridge
cat > /tmp/hostapd-bridge.conf << EOF
interface=$WLAN
bridge=$BRIDGE
driver=nl80211
ssid=BridgedAP
hw_mode=g
channel=6
wpa=2
wpa_passphrase=password
EOF
# Start hostapd
sudo hostapd -B /tmp/hostapd-bridge.conf
# Start DHCP server on bridge
sudo dnsmasq --interface=$BRIDGE \
--dhcp-range=192.168.1.100,192.168.1.200,12h
VLAN Support
Static VLAN Assignment
hostapd.conf:
interface=wlan0
ssid=MultiVLAN_WiFi
wpa=2
wpa_passphrase=password
# Enable dynamic VLAN
dynamic_vlan=1
vlan_file=/etc/hostapd/vlan.conf
/etc/hostapd/vlan.conf:
# VLAN_ID VLAN_IFNAME
1 wlan0.1
10 wlan0.10
20 wlan0.20
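The vlan_file format is two whitespace-separated columns with `#` comments; parsing it into a mapping is straightforward (sample text is the example file above):

```python
def parse_vlan_file(text: str) -> dict:
    """Parse hostapd's vlan_file format: `VLAN_ID IFNAME` per line,
    ignoring blank lines and # comments."""
    mapping = {}
    for line in text.splitlines():
        line = line.split("#", 1)[0].strip()
        if not line:
            continue
        vlan_id, ifname = line.split()
        mapping[int(vlan_id)] = ifname
    return mapping

sample = """\
# VLAN_ID VLAN_IFNAME
1 wlan0.1
10 wlan0.10
20 wlan0.20
"""
vlans = parse_vlan_file(sample)
```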
VLAN with RADIUS
interface=wlan0
ssid=Enterprise_VLAN
wpa=2
wpa_key_mgmt=WPA-EAP
ieee8021x=1
# RADIUS server
auth_server_addr=192.168.1.10
auth_server_port=1812
auth_server_shared_secret=secret
# Dynamic VLAN from RADIUS
dynamic_vlan=1
vlan_naming=1
RADIUS Authentication
Internal EAP Server
interface=wlan0
ssid=InternalEAP_WiFi
# Use hostapd's internal EAP server
ieee8021x=1
eap_server=1
eap_user_file=/etc/hostapd/hostapd.eap_user
ca_cert=/etc/hostapd/ca.pem
server_cert=/etc/hostapd/server.pem
private_key=/etc/hostapd/server-key.pem
private_key_passwd=keypassword
wpa=2
wpa_key_mgmt=WPA-EAP
rsn_pairwise=CCMP
/etc/hostapd/hostapd.eap_user:
# Phase 1 (outer) authentication
* PEAP
# Phase 2 (inner) authentication -- [2] marks phase 2 entries
"user1" MSCHAPV2 "password1" [2]
"user2" MSCHAPV2 "password2" [2]
# TLS (phase 1, certificate-based)
"client1" TLS
External RADIUS Server
interface=wlan0
ssid=RADIUS_WiFi
wpa=2
wpa_key_mgmt=WPA-EAP
ieee8021x=1
# Primary RADIUS server
auth_server_addr=192.168.1.10
auth_server_port=1812
auth_server_shared_secret=sharedsecret
# Backup RADIUS server
auth_server_addr=192.168.1.11
auth_server_port=1812
auth_server_shared_secret=sharedsecret
# Accounting
acct_server_addr=192.168.1.10
acct_server_port=1813
acct_server_shared_secret=sharedsecret
# Disable internal EAP
eap_server=0
802.11n/ac/ax Configuration
802.11n (WiFi 4) - 2.4 GHz
interface=wlan0
ssid=N_WiFi_2_4GHz
hw_mode=g
channel=6
# Enable 802.11n
ieee80211n=1
wmm_enabled=1
# HT capabilities
ht_capab=[HT40+][SHORT-GI-20][SHORT-GI-40][DSSS_CCK-40]
wpa=2
wpa_passphrase=password
802.11n (WiFi 4) - 5 GHz
interface=wlan0
ssid=N_WiFi_5GHz
hw_mode=a
channel=36
ieee80211n=1
wmm_enabled=1
# 40 MHz channel
ht_capab=[HT40+][SHORT-GI-20][SHORT-GI-40]
wpa=2
wpa_passphrase=password
802.11ac (WiFi 5)
interface=wlan0
ssid=AC_WiFi
hw_mode=a
channel=36
# 802.11n required
ieee80211n=1
ht_capab=[HT40+][SHORT-GI-20][SHORT-GI-40]
# 802.11ac
ieee80211ac=1
vht_capab=[MAX-MPDU-11454][SHORT-GI-80][TX-STBC-2BY1][RX-STBC-1]
# 80 MHz channel
vht_oper_chwidth=1
vht_oper_centr_freq_seg0_idx=42
wmm_enabled=1
wpa=2
wpa_passphrase=password
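For an 80 MHz channel, vht_oper_centr_freq_seg0_idx is the channel index at the center of the four-channel block; for channels 36-48 that is 42. The lookup can be computed for the contiguous part of the 5 GHz band (UNII-3 at channel 149+ is offset differently and is excluded here):

```python
def vht80_center_idx(primary_channel: int) -> int:
    """Center channel index of the 80 MHz block containing the primary
    channel, valid for channels 36-144 (e.g. 36-48 -> 42, 52-64 -> 58)."""
    if not 36 <= primary_channel <= 144:
        raise ValueError("outside the contiguous 36-144 channel range")
    block_start = 36 + ((primary_channel - 36) // 16) * 16
    return block_start + 6
```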
802.11ax (WiFi 6)
interface=wlan0
ssid=AX_WiFi
hw_mode=a
channel=36
# 802.11n
ieee80211n=1
ht_capab=[HT40+][SHORT-GI-20][SHORT-GI-40]
# 802.11ac
ieee80211ac=1
vht_oper_chwidth=1
vht_oper_centr_freq_seg0_idx=42
# 802.11ax
ieee80211ax=1
he_su_beamformer=1
he_su_beamformee=1
he_mu_beamformer=1
wmm_enabled=1
wpa=2 # RSN (covers WPA2/WPA3; wpa=3 would mean legacy WPA+WPA2 mixed mode)
wpa_key_mgmt=SAE
sae_password=password
ieee80211w=2
Monitoring and Management
hostapd_cli
# Connect to running hostapd
hostapd_cli
# Or specify interface
hostapd_cli -i wlan0
# Get status
hostapd_cli status
# List connected stations
hostapd_cli all_sta
# Disconnect a station
hostapd_cli disassociate <MAC>
# Reload configuration
hostapd_cli reload
# Enable/disable
hostapd_cli disable
hostapd_cli enable
Monitor Connected Clients
# List all stations
hostapd_cli all_sta
# Detailed station info
hostapd_cli sta <MAC_ADDRESS>
# Example output:
# dot11RSNAStatsSTAAddress=aa:bb:cc:dd:ee:ff
# dot11RSNAStatsVersion=1
# dot11RSNAStatsSelectedPairwiseCipher=00-0f-ac-4
# dot11RSNAStatsTKIPLocalMICFailures=0
# flags=[AUTH][ASSOC][AUTHORIZED]
Signal Strength
# Show signal strength for connected clients
for mac in $(hostapd_cli all_sta | grep '^[0-9a-f]' | cut -d' ' -f1); do
echo "Station: $mac"
hostapd_cli sta $mac | grep signal
done
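hostapd_cli all_sta prints each station's MAC on its own line followed by key=value fields, so it groups naturally into a per-station dict. A Python sketch (the sample output is illustrative):

```python
import re

MAC_RE = re.compile(r"^[0-9a-f]{2}(:[0-9a-f]{2}){5}$")

def parse_all_sta(output: str) -> dict:
    """Group `hostapd_cli all_sta` output into {mac: {field: value}}."""
    stations, current = {}, None
    for line in output.splitlines():
        line = line.strip()
        if MAC_RE.match(line):
            current = stations.setdefault(line, {})
        elif current is not None and "=" in line:
            key, _, value = line.partition("=")
            current[key] = value
    return stations

sample = """\
aa:bb:cc:dd:ee:ff
signal=-55
flags=[AUTH][ASSOC][AUTHORIZED]
11:22:33:44:55:66
signal=-70
"""
stations = parse_all_sta(sample)
```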
Troubleshooting
Check Configuration
# Validate configuration by starting hostapd in the foreground;
# parse errors are reported immediately at startup
# (the -t flag only adds timestamps to debug output, it is not a syntax check)
sudo hostapd -dd /etc/hostapd/hostapd.conf
Debug Mode
# Run in foreground with debug
sudo systemctl stop hostapd
sudo hostapd -d /etc/hostapd/hostapd.conf
# More verbose
sudo hostapd -dd /etc/hostapd/hostapd.conf
Common Issues
Cannot start AP - device busy:
# Check if NetworkManager is controlling interface
nmcli device status
# Unmanage interface
sudo nmcli device set wlan0 managed no
# Or disable NetworkManager for interface
# /etc/NetworkManager/NetworkManager.conf
[keyfile]
unmanaged-devices=mac:aa:bb:cc:dd:ee:ff
sudo systemctl restart NetworkManager
Channel not available:
# Check supported channels
iw list | grep -A 20 "Frequencies:"
# Check regulatory domain
iw reg get
# Set country code
sudo iw reg set US
# Or in hostapd.conf
country_code=US
ieee80211d=1
Interface doesn't support AP mode:
# Check supported modes
iw list | grep -A 10 "Supported interface modes:"
# Should show:
# * AP
# * AP/VLAN
# If not present, hardware doesn't support AP mode
Authentication failures:
# Check logs
sudo journalctl -u hostapd -f
# Common causes:
# 1. Wrong password
# 2. Incompatible security settings
# 3. Client doesn't support WPA3
# 4. PMF issues
# Try WPA2 for compatibility
wpa=2
wpa_key_mgmt=WPA-PSK
No DHCP addresses:
# Check if DHCP server is running
ps aux | grep dnsmasq
# Check interface has IP
ip addr show wlan0
# Test DHCP manually
sudo dnsmasq --no-daemon --interface=wlan0 \
--dhcp-range=192.168.50.10,192.168.50.100,12h \
--log-queries
Integration with systemd
systemd Service
# Enable and start
sudo systemctl unmask hostapd
sudo systemctl enable hostapd
sudo systemctl start hostapd
# Status
sudo systemctl status hostapd
# Logs
sudo journalctl -u hostapd -f
Custom Service File
/etc/systemd/system/hostapd.service:
[Unit]
Description=Access point and authentication server
After=network.target
[Service]
Type=forking
PIDFile=/var/run/hostapd.pid
ExecStart=/usr/sbin/hostapd -B -P /var/run/hostapd.pid /etc/hostapd/hostapd.conf
ExecReload=/bin/kill -HUP $MAINPID
Restart=on-failure
RestartSec=5
[Install]
WantedBy=multi-user.target
Configuration File Location
/etc/default/hostapd:
DAEMON_CONF="/etc/hostapd/hostapd.conf"
Best Practices
Security
- Use WPA3 when possible:
wpa=2
wpa_key_mgmt=SAE
ieee80211w=2
- Strong passwords:
# Minimum 12 characters
wpa_passphrase=MyVerySecurePassword123!
- Disable WPS in production:
wps_state=0
- Enable PMF:
ieee80211w=1 # Optional
# or
ieee80211w=2 # Required (WPA3)
- Guest network isolation:
bss=wlan0_0
ssid=Guest
ap_isolate=1
Performance
- Use 5 GHz for better performance:
hw_mode=a
channel=36
- Enable 802.11n/ac:
ieee80211n=1
ieee80211ac=1
wmm_enabled=1
- Choose non-overlapping channels:
2.4 GHz: 1, 6, 11
5 GHz: Many options (36, 40, 44, 48...)
- Limit max clients:
max_num_sta=50
Reliability
- Set country code:
country_code=US
ieee80211d=1
ieee80211h=1
- Enable logging:
logger_syslog=-1
logger_syslog_level=2
- Automatic restart:
sudo systemctl enable hostapd
Summary
hostapd creates WiFi access points on Linux:
Basic workflow:
- Configure /etc/hostapd/hostapd.conf
- Start: sudo hostapd /etc/hostapd/hostapd.conf
- Configure DHCP server (dnsmasq)
- Enable IP forwarding and NAT (for internet sharing)
Minimal config:
interface=wlan0
ssid=MyWiFi
channel=6
wpa=2
wpa_passphrase=password
Essential commands:
- hostapd -dd: Run in foreground with debug output (reports config errors at startup)
- hostapd_cli: Control running AP
- systemctl start hostapd: Start service
Common tasks:
- WPA2 AP: Configure wpa=2 and wpa_passphrase
- WPA3 AP: Use wpa_key_mgmt=SAE and ieee80211w=2
- Guest network: Use multi-BSS with ap_isolate=1
- Bridge mode: Set bridge=br0
Embedded Systems
Comprehensive guide to embedded systems development, microcontrollers, and hardware interfacing.
Table of Contents
- Introduction
- Development Platforms
- Core Concepts
- Communication Protocols
- Peripheral Interfaces
- Getting Started
Introduction
Embedded systems are specialized computing systems designed to perform dedicated functions within larger mechanical or electrical systems. They combine hardware and software to control devices and interact with the physical world.
Key Characteristics
- Real-time Operation: Deterministic response to events
- Resource Constraints: Limited memory, processing power, and energy
- Reliability: Must operate continuously for extended periods
- Hardware Integration: Direct interaction with sensors and actuators
- Application-Specific: Optimized for particular tasks
Architecture Overview
┌─────────────────────────────────────────┐
│ Embedded System │
├─────────────────────────────────────────┤
│ Application Layer │
│ ├─ User Code │
│ └─ Libraries & Frameworks │
├─────────────────────────────────────────┤
│ HAL/Drivers │
│ ├─ Peripheral Drivers │
│ └─ Hardware Abstraction Layer │
├─────────────────────────────────────────┤
│ Microcontroller/Processor │
│ ├─ CPU Core (ARM, AVR, RISC-V) │
│ ├─ Memory (Flash, RAM, EEPROM) │
│ ├─ Peripherals (GPIO, UART, SPI...) │
│ └─ Clock & Power Management │
├─────────────────────────────────────────┤
│ Hardware │
│ ├─ Sensors │
│ ├─ Actuators │
│ └─ External Interfaces │
└─────────────────────────────────────────┘
Development Platforms
Microcontroller Platforms
| Platform | Processor | Clock | Memory | Use Cases |
|---|---|---|---|---|
| Arduino | AVR/ARM | 16-84 MHz | 2KB-256KB RAM | Prototyping, education, hobbyist projects |
| ESP32 | Xtensa/RISC-V | 160-240 MHz | 520KB RAM | IoT, WiFi/BLE projects |
| STM32 | ARM Cortex-M | 48-550 MHz | 32KB-2MB RAM | Professional, industrial applications |
| AVR | AVR | 1-20 MHz | 512B-16KB RAM | Low-power, bare-metal programming |
| Raspberry Pi | ARM Cortex-A | 700MHz-2.4GHz | 512MB-8GB RAM | Linux-based, complex applications |
Comparison Matrix
Complexity/Capability
^
|
RPi | ┌──────────┐
| │ │
STM32| │ │ ┌──────┐
| │ │ │ │
ESP32| │ │ │ │ ┌─────┐
| │ │ │ │ │ │
ARD | │ │ │ │ │ │ ┌────┐
| │ │ │ │ │ │ │ │
AVR | │ │ │ │ │ │ │ │
| └──────────┴────┴──────┴──┴─────┴──┴────┘
+──────────────────────────────────────────> Cost
Low High
Core Concepts
Memory Architecture
Flash Memory (Program Storage)
- Stores program code and constant data
- Non-volatile (persists without power)
- Typically 8KB to several MB
- Limited write cycles (10K-100K)
SRAM (Runtime Memory)
- Stores variables and stack during execution
- Volatile (lost when power removed)
- Fast access, limited size
- Critical resource in embedded systems
EEPROM (Persistent Data)
- Stores configuration and calibration data
- Non-volatile, byte-addressable
- Limited write cycles but higher than Flash
- Slower than SRAM
Memory Map Example (ATmega328P) — note these are three separate address spaces (Harvard architecture), not one continuous map:
Flash (program memory, byte addresses)
┌────────────────┐ 0x0000
│ Flash (32KB)   │
│ Program Code   │
└────────────────┘ 0x7FFF
SRAM (data space; 0x0000-0x00FF holds registers and I/O)
┌────────────────┐ 0x0100
│ SRAM (2KB)     │
│ Variables      │
│ Stack          │
└────────────────┘ 0x08FF
EEPROM
┌────────────────┐ 0x0000
│ EEPROM (1KB)   │
│ Persistent     │
└────────────────┘ 0x03FF
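The end addresses of these regions check out arithmetically: the last address of a region is start + size - 1.

```python
def last_address(start: int, size_bytes: int) -> int:
    """Last byte address of a memory region: start + size - 1."""
    return start + size_bytes - 1

# ATmega328P regions (each in its own address space)
flash_end = last_address(0x0000, 32 * 1024)   # 32 KB flash
sram_end = last_address(0x0100, 2 * 1024)     # 2 KB SRAM after regs/I-O
eeprom_end = last_address(0x0000, 1024)       # 1 KB EEPROM
```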
Power Management
Operating Modes
- Active Mode: Full operation, highest power consumption
- Idle Mode: CPU stopped, peripherals running
- Sleep Mode: Most peripherals disabled
- Deep Sleep: Minimal power, wake on interrupt only
Power Saving Techniques
// Example: AVR Sleep Mode
#include <avr/sleep.h>
#include <avr/power.h>
void enterSleepMode() {
set_sleep_mode(SLEEP_MODE_PWR_DOWN);
sleep_enable();
// Disable unnecessary peripherals
power_adc_disable();
power_spi_disable();
power_timer0_disable();
sleep_mode(); // Enter sleep
// Wake up here after interrupt
sleep_disable();
// Re-enable peripherals
power_all_enable();
}
Interrupt-Driven Programming
Interrupts allow the processor to respond to events immediately without polling.
// Example: External Interrupt
volatile bool buttonPressed = false;
// Interrupt Service Routine (ISR)
void EXTI0_IRQHandler(void) {
if (EXTI->PR & EXTI_PR_PR0) {
buttonPressed = true;
EXTI->PR |= EXTI_PR_PR0; // Clear interrupt flag
}
}
int main(void) {
// Setup interrupt
RCC->APB2ENR |= RCC_APB2ENR_IOPAEN;
GPIOA->CRL &= ~GPIO_CRL_CNF0;
GPIOA->CRL |= GPIO_CRL_CNF0_1; // Input with pull-up
AFIO->EXTICR[0] = AFIO_EXTICR1_EXTI0_PA;
EXTI->IMR |= EXTI_IMR_MR0;
EXTI->FTSR |= EXTI_FTSR_TR0; // Falling edge
NVIC_EnableIRQ(EXTI0_IRQn);
while (1) {
if (buttonPressed) {
// Handle button press
buttonPressed = false;
}
// Main loop continues
}
}
Communication Protocols
Serial Protocols Overview
| Protocol | Type | Speed | Wires | Use Case |
|---|---|---|---|---|
| UART | Asynchronous | Up to 1 Mbps | 2 (TX/RX) | Debug, GPS, Bluetooth modules |
| SPI | Synchronous | Up to 50 Mbps | 4+ (MOSI/MISO/SCK/CS) | SD cards, displays, high-speed sensors |
| I2C | Synchronous | 100-400 kHz | 2 (SDA/SCL) | Sensors, RTCs, EEPROMs |
| CAN | Differential | Up to 1 Mbps | 2 (CAN_H/CAN_L) | Automotive, industrial |
| USB | Differential | 1.5-480 Mbps | 2 (D+/D-) | PC interface, peripherals |
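To get a feel for these speeds: a UART frame at 8N1 carries 10 bits per byte (start + 8 data + stop), so sending N bytes takes N*10/baud seconds.

```python
def uart_transfer_time(num_bytes: int, baud: int) -> float:
    """Seconds to send num_bytes over UART at 8N1 (10 bits per byte)."""
    return num_bytes * 10 / baud

# 1 KB at 115200 baud takes roughly 89 ms
t = uart_transfer_time(1024, 115200)
```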
Protocol Comparison
Speed (Mbps)
^
|
100 | ┌─── USB 2.0
| │
50 | ┌─── SPI│
| │ │
10 | │ │
| │ │
1 | ┌─ UART │ │
| │ │ │ │
0.1 | │ I2C│ │ │
| │ │ │ │
└──┴────┴───┴───────┴────────> Complexity
Low High
Peripheral Interfaces
Digital I/O (GPIO)
- General Purpose Input/Output pins
- Digital HIGH/LOW states
- Input modes: floating, pull-up, pull-down
- Output modes: push-pull, open-drain
Analog Interfaces
- ADC: Convert analog voltages to digital values
- DAC: Convert digital values to analog voltages
- PWM: Pulse Width Modulation for analog-like output
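The conversions behind these interfaces are simple ratios; for example, a 10-bit ADC with a 5 V reference maps a raw reading to volts as raw * Vref / 1023, and a PWM duty cycle is the compare value over the timer period:

```python
def adc_to_voltage(raw: int, vref: float = 5.0, bits: int = 10) -> float:
    """Convert a raw ADC reading to volts: raw * Vref / (2^bits - 1)."""
    return raw * vref / ((1 << bits) - 1)

def pwm_duty(compare: int, top: int) -> float:
    """PWM duty cycle as a fraction of the timer period."""
    return compare / top
```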
Timing and Control
- Timers: Hardware timers for precise timing
- Interrupts: Event-driven programming
- Watchdog: System reliability and reset
Getting Started
Development Environment Setup
1. Choose Your Platform
Start with Arduino for beginners, or jump to STM32/ESP32 for more advanced projects.
2. Install Tools
For Arduino:
# Download Arduino IDE from arduino.cc
# Or use Arduino CLI
curl -fsSL https://raw.githubusercontent.com/arduino/arduino-cli/master/install.sh | sh
For STM32:
# Install STM32CubeIDE
# Download from st.com
# Or use PlatformIO
pip install platformio
For ESP32:
# Add ESP32 to Arduino IDE
# Or use ESP-IDF
git clone --recursive https://github.com/espressif/esp-idf.git
cd esp-idf
./install.sh
3. Hardware Setup
Minimum Requirements:
- Development board (Arduino Uno, ESP32, STM32 Nucleo, etc.)
- USB cable
- Computer with IDE installed
- Optional: Breadboard, jumper wires, components
Development Kit:
Essential Components:
├─ Microcontroller board
├─ USB cable
├─ Breadboard
├─ Jumper wires (male-male, male-female)
├─ LEDs and resistors (220Ω)
├─ Push buttons
├─ Potentiometer (10kΩ)
└─ Multimeter
Sensors (Optional):
├─ Temperature (DHT11/22, DS18B20)
├─ Distance (HC-SR04 ultrasonic)
├─ Light (LDR, BH1750)
└─ Motion (PIR, MPU6050)
First Program: Blink LED
Arduino Version
// Blink LED on pin 13
void setup() {
pinMode(13, OUTPUT);
}
void loop() {
digitalWrite(13, HIGH);
delay(1000);
digitalWrite(13, LOW);
delay(1000);
}
STM32 HAL Version
#include "stm32f4xx_hal.h"
int main(void) {
HAL_Init();
SystemClock_Config();
__HAL_RCC_GPIOA_CLK_ENABLE();
GPIO_InitTypeDef GPIO_InitStruct = {0};
GPIO_InitStruct.Pin = GPIO_PIN_5;
GPIO_InitStruct.Mode = GPIO_MODE_OUTPUT_PP;
GPIO_InitStruct.Pull = GPIO_NOPULL;
GPIO_InitStruct.Speed = GPIO_SPEED_FREQ_LOW;
HAL_GPIO_Init(GPIOA, &GPIO_InitStruct);
while (1) {
HAL_GPIO_TogglePin(GPIOA, GPIO_PIN_5);
HAL_Delay(1000);
}
}
Bare Metal AVR Version
#include <avr/io.h>
#include <util/delay.h>
int main(void) {
DDRB |= (1 << DDB5); // Set PB5 as output
while (1) {
PORTB |= (1 << PORTB5); // LED on
_delay_ms(1000);
PORTB &= ~(1 << PORTB5); // LED off
_delay_ms(1000);
}
return 0;
}
Learning Path
┌─────────────────────────────────────────────┐
│ Level 1: Fundamentals │
├─────────────────────────────────────────────┤
│ • Digital I/O (LED, button) │
│ • Analog input (ADC, potentiometer) │
│ • PWM (LED brightness, motor speed) │
│ • Serial communication (UART debug) │
└─────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────┐
│ Level 2: Intermediate │
├─────────────────────────────────────────────┤
│ • Timers and interrupts │
│ • I2C sensors (temperature, accelerometer) │
│ • SPI devices (SD card, display) │
│ • State machines │
└─────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────┐
│ Level 3: Advanced │
├─────────────────────────────────────────────┤
│ • DMA transfers │
│ • RTOS (FreeRTOS) │
│ • Low-power optimization │
│ • Bootloaders and OTA updates │
└─────────────────────────────────────────────┘
Common Project Ideas
- Beginner Projects
  - LED blink and patterns
  - Button-controlled LED
  - Temperature monitor
  - Light-sensitive nightlight
- Intermediate Projects
  - Digital thermometer with display
  - Motor speed controller
  - Distance measurement system
  - Data logger with SD card
- Advanced Projects
  - Weather station with WiFi
  - Robot controller
  - Home automation system
  - Wireless sensor network
Best Practices
Code Organization
// Good structure
├─ src/
│ ├─ main.c
│ ├─ drivers/
│ │ ├─ sensor.c
│ │ └─ display.c
│ └─ app/
│ ├─ control.c
│ └─ config.c
├─ inc/
│ ├─ sensor.h
│ ├─ display.h
│ └─ config.h
└─ Makefile
Design Principles
- Keep ISRs Short: Minimal processing in interrupt handlers
- Use Volatile: For variables modified by ISRs
- Debounce Inputs: Software or hardware debouncing for buttons
- Watchdog Timer: Implement system recovery
- Power Efficiency: Use sleep modes when idle
- Error Handling: Check return values and handle failures
- Documentation: Comment complex logic and register operations
Debugging Techniques
// UART debug output
void debug_print(const char* msg) {
uart_send_string(msg);
}
// LED status indicators
#define LED_ERROR GPIO_PIN_0
#define LED_OK GPIO_PIN_1
#define LED_BUSY GPIO_PIN_2
// Assert macro (do/while avoids dangling-else issues; note that a //
// comment before a line-continuation backslash would swallow the next line)
#define ASSERT(expr)                                  \
    do {                                              \
        if (!(expr)) {                                \
            debug_print("Assert failed: " #expr);     \
            while (1) { } /* halt */                  \
        }                                             \
    } while (0)
Resources
Documentation
- Platform-specific datasheets and reference manuals
- Peripheral application notes
- HAL/LL library documentation
Tools
- Oscilloscope: Analyze signals and timing
- Logic Analyzer: Debug digital protocols
- Multimeter: Measure voltages and continuity
- Debugger: JTAG/SWD for step-through debugging
Communities
- Arduino Forum
- STM32 Community
- ESP32 Forum
- Reddit: r/embedded, r/arduino
- Stack Overflow: Embedded tag
AVR Microcontrollers
Comprehensive guide to AVR microcontroller programming with register-level control and bare-metal development.
Table of Contents
- Introduction
- AVR Architecture
- Development Setup
- Register Programming
- GPIO Control
- Timers and Counters
- Interrupts
- Communication Protocols
- Advanced Topics
Introduction
AVR is a family of 8-bit RISC microcontrollers developed by Atmel (now Microchip). They are widely used in Arduino boards and embedded systems due to their simplicity, efficiency, and low cost.
Key Features
- 8-bit RISC Architecture: Harvard architecture with separate program and data memory
- Clock Speed: 1-20 MHz
- Flash Memory: 2-256 KB
- SRAM: 128 bytes - 16 KB
- EEPROM: 64 bytes - 4 KB
- Peripherals: GPIO, Timers, ADC, UART, SPI, I2C
- Power Efficient: Multiple sleep modes
- Price: $1-5
Popular AVR Microcontrollers
| MCU | Flash | RAM | EEPROM | GPIO | ADC | Timers | Package | Use Case |
|---|---|---|---|---|---|---|---|---|
| ATtiny13 | 1 KB | 64 B | 64 B | 6 | 4 | 1 | 8-pin | Ultra-small projects |
| ATtiny85 | 8 KB | 512 B | 512 B | 6 | 4 | 2 | 8-pin | Small projects |
| ATmega8 | 8 KB | 1 KB | 512 B | 23 | 6 | 3 | 28-pin | Entry level |
| ATmega328P | 32 KB | 2 KB | 1 KB | 23 | 6 | 3 | 28-pin | Arduino Uno |
| ATmega2560 | 256 KB | 8 KB | 4 KB | 86 | 16 | 6 | 100-pin | Arduino Mega |
AVR Architecture
Memory Organization
┌──────────────────────────────────────┐
│ AVR Memory Map │
├──────────────────────────────────────┤
│ Program Memory (Flash) │
│ ┌────────────────────────────────┐ │
│ │ 0x0000: Interrupt Vectors │ │
│ │ 0x0034: Program Code │ │
│ │ ... │ │
│ │ End: Bootloader (optional) │ │
│ └────────────────────────────────┘ │
├──────────────────────────────────────┤
│ Data Memory (SRAM) │
│ ┌────────────────────────────────┐ │
│ │ 0x0000-0x001F: Registers (R0-R31)│
│ │ 0x0020-0x005F: I/O Registers │ │
│ │ 0x0060-0x00FF: Extended I/O │ │
│ │ 0x0100-... : SRAM │ │
│ │ ... : Stack (grows ↓) │ │
│ └────────────────────────────────┘ │
├──────────────────────────────────────┤
│ EEPROM (Non-volatile) │
│ ┌────────────────────────────────┐ │
│ │ 0x0000: User data storage │ │
│ │ ... │ │
│ └────────────────────────────────┘ │
└──────────────────────────────────────┘
Registers
General Purpose Registers
R0-R31: 32 general-purpose 8-bit registers
R26-R27: X pointer (XL, XH)
R28-R29: Y pointer (YL, YH)
R30-R31: Z pointer (ZL, ZH)
Status Register (SREG)
Bit 7: I - Global Interrupt Enable
Bit 6: T - Bit Copy Storage (used by BST/BLD)
Bit 5: H - Half Carry Flag
Bit 4: S - Sign Flag
Bit 3: V - Overflow Flag
Bit 2: N - Negative Flag
Bit 1: Z - Zero Flag
Bit 0: C - Carry Flag
ATmega328P Pinout
ATmega328P (DIP-28)
┌───∪───┐
RESET 1 ─┤ ├─ 28 PC5/ADC5/SCL
RXD/D0 2 ─┤ ├─ 27 PC4/ADC4/SDA
TXD/D1 3 ─┤ ├─ 26 PC3/ADC3
INT0/D2 4 ─┤ ├─ 25 PC2/ADC2
INT1/D3 5 ─┤ ├─ 24 PC1/ADC1
D4 6 ─┤ ├─ 23 PC0/ADC0
VCC 7 ─┤ ├─ 22 GND
GND 8 ─┤ ├─ 21 AREF
XTAL1 9 ─┤ ├─ 20 AVCC
XTAL2 10 ─┤ ├─ 19 PB5/SCK
D5 11 ─┤ ├─ 18 PB4/MISO
D6 12 ─┤ ├─ 17 PB3/MOSI
D7 13 ─┤ ├─ 16 PB2/SS
D8 14 ─┤ ├─ 15 PB1/OC1A
└───────┘
GPIO Ports:
Port B (PB0-PB5): Digital I/O, SPI
Port C (PC0-PC5): Analog input (ADC), I2C
Port D (PD0-PD7): Digital I/O, UART, Interrupts
Development Setup
AVR-GCC Toolchain
# Install AVR tools (Ubuntu/Debian)
sudo apt install gcc-avr avr-libc avrdude
# Install on Arch Linux
sudo pacman -S avr-gcc avr-libc avrdude
# Install on macOS (avr-gcc is provided by the osx-cross Homebrew tap)
brew tap osx-cross/avr
brew install avr-gcc avrdude
# Verify installation
avr-gcc --version
avrdude -v
Project Structure
project/
├── main.c
├── Makefile
└── README.md
Makefile Template
# AVR Makefile
MCU = atmega328p
F_CPU = 16000000UL
BAUD = 9600
CC = avr-gcc
OBJCOPY = avr-objcopy
OBJDUMP = avr-objdump
SIZE = avr-size
TARGET = main
SRC = main.c
CFLAGS = -mmcu=$(MCU) -DF_CPU=$(F_CPU) -DBAUD=$(BAUD)
CFLAGS += -Os -Wall -Wextra -std=c99
# Programmer settings
PROGRAMMER = arduino
PORT = /dev/ttyUSB0
all: $(TARGET).hex
$(TARGET).elf: $(SRC)
$(CC) $(CFLAGS) -o $@ $^
$(SIZE) $@
$(TARGET).hex: $(TARGET).elf
$(OBJCOPY) -O ihex -R .eeprom $< $@
flash: $(TARGET).hex
avrdude -c $(PROGRAMMER) -p $(MCU) -P $(PORT) -U flash:w:$<
clean:
rm -f $(TARGET).elf $(TARGET).hex
.PHONY: all flash clean
Compiling and Flashing
# Compile
make
# Flash to device
make flash
# Clean build files
make clean
# Manual commands
avr-gcc -mmcu=atmega328p -DF_CPU=16000000UL -Os -o main.elf main.c
avr-objcopy -O ihex -R .eeprom main.elf main.hex
avrdude -c arduino -p atmega328p -P /dev/ttyUSB0 -U flash:w:main.hex
Register Programming
Understanding Registers
AVR programming requires direct manipulation of hardware registers. Each peripheral has associated registers for control and data.
Register Operations
#include <avr/io.h>
/* Set bit (set to 1) */
PORTB |= (1 << PB5);
/* Clear bit (set to 0) */
PORTB &= ~(1 << PB5);
/* Toggle bit */
PORTB ^= (1 << PB5);
/* Check bit */
if (PIND & (1 << PD2)) {
// Bit is set
}
/* Set multiple bits */
PORTB |= (1 << PB0) | (1 << PB1) | (1 << PB2);
/* Clear multiple bits */
PORTB &= ~((1 << PB0) | (1 << PB1));
/* Write entire register */
PORTB = 0b10101010;
GPIO Control
Port Registers
Each GPIO port has three registers:
- DDRx: Data Direction Register (1 = Output, 0 = Input)
- PORTx: Port Output Register (Output value or pull-up enable)
- PINx: Port Input Register (Read input state)
Basic GPIO Example
#include <avr/io.h>
#include <util/delay.h>
int main(void) {
/* Set PB5 (Arduino pin 13) as output */
DDRB |= (1 << DDB5);
/* Main loop */
while (1) {
/* Turn LED on */
PORTB |= (1 << PORTB5);
_delay_ms(1000);
/* Turn LED off */
PORTB &= ~(1 << PORTB5);
_delay_ms(1000);
}
return 0;
}
Button Input with Pull-up
#include <avr/io.h>
#include <util/delay.h>
int main(void) {
/* PB5 as output (LED) */
DDRB |= (1 << DDB5);
/* PD2 as input (button) */
DDRD &= ~(1 << DDD2);
/* Enable pull-up resistor on PD2 */
PORTD |= (1 << PORTD2);
while (1) {
/* Check if button pressed (active low) */
if (!(PIND & (1 << PIND2))) {
PORTB |= (1 << PORTB5); // LED on
} else {
PORTB &= ~(1 << PORTB5); // LED off
}
_delay_ms(10); // Debounce delay
}
return 0;
}
Multiple LED Control
#include <avr/io.h>
#include <util/delay.h>
int main(void) {
/* Set PB0-PB5 as outputs */
DDRB = 0b00111111;
while (1) {
/* Running LED pattern */
for (uint8_t i = 0; i < 6; i++) {
PORTB = (1 << i);
_delay_ms(200);
}
/* Reverse */
for (uint8_t i = 6; i > 0; i--) {
PORTB = (1 << (i-1));
_delay_ms(200);
}
}
return 0;
}
Timers and Counters
AVR timers are versatile peripherals for timing, counting, PWM generation, and more.
Timer0 (8-bit)
#include <avr/io.h>
#include <avr/interrupt.h>
volatile uint32_t milliseconds = 0;
/* Timer0 overflow interrupt (fires every 256 * 64 cycles ≈ 1.024 ms at 16 MHz) */
ISR(TIMER0_OVF_vect) {
milliseconds++;
}
void timer0_init(void) {
/* Set prescaler to 64 */
TCCR0B |= (1 << CS01) | (1 << CS00);
/* Enable overflow interrupt */
TIMSK0 |= (1 << TOIE0);
/* Enable global interrupts */
sei();
}
int main(void) {
DDRB |= (1 << DDB5);
timer0_init();
while (1) {
if (milliseconds >= 1000) { /* note: a 32-bit read is not atomic on AVR; guard with cli()/sei() in production code */
milliseconds = 0;
PORTB ^= (1 << PORTB5);
}
}
return 0;
}
PWM with Timer1 (16-bit)
#include <avr/io.h>
#include <util/delay.h>
void pwm_init(void) {
/* Set PB1 (OC1A) as output */
DDRB |= (1 << DDB1);
/* Fast PWM, 10-bit, non-inverted */
TCCR1A |= (1 << WGM11) | (1 << WGM10);
TCCR1A |= (1 << COM1A1);
TCCR1B |= (1 << WGM12) | (1 << CS10); // No prescaling
/* Set initial duty cycle */
OCR1A = 512; // 50% duty cycle (0-1023)
}
int main(void) {
pwm_init();
while (1) {
/* Fade in */
for (uint16_t i = 0; i <= 1023; i += 10) {
OCR1A = i;
_delay_ms(20);
}
/* Fade out */
for (int16_t i = 1023; i >= 0; i -= 10) { /* signed index avoids unsigned wrap-around below zero */
OCR1A = i;
_delay_ms(20);
}
}
return 0;
}
Timer2 CTC Mode (Precise Timing)
#include <avr/io.h>
#include <avr/interrupt.h>
volatile uint8_t flag = 0;
/* Timer2 compare match interrupt - fires every 1ms */
ISR(TIMER2_COMPA_vect) {
static uint16_t count = 0;
count++;
if (count >= 1000) { // 1 second
count = 0;
flag = 1;
}
}
void timer2_init(void) {
/* CTC mode */
TCCR2A |= (1 << WGM21);
/* Prescaler 64: 16MHz / 64 = 250kHz */
TCCR2B |= (1 << CS22);
/* Compare value for 1ms: 250kHz / 250 = 1kHz */
OCR2A = 249;
/* Enable compare match interrupt */
TIMSK2 |= (1 << OCIE2A);
sei();
}
int main(void) {
DDRB |= (1 << DDB5);
timer2_init();
while (1) {
if (flag) {
flag = 0;
PORTB ^= (1 << PORTB5);
}
}
return 0;
}
Interrupts
External Interrupts
#include <avr/io.h>
#include <avr/interrupt.h>
volatile uint8_t led_state = 0;
/* INT0 interrupt handler */
ISR(INT0_vect) {
led_state = !led_state;
if (led_state) {
PORTB |= (1 << PORTB5);
} else {
PORTB &= ~(1 << PORTB5);
}
}
void int0_init(void) {
/* PD2 as input with pull-up */
DDRD &= ~(1 << DDD2);
PORTD |= (1 << PORTD2);
/* Trigger on falling edge */
EICRA |= (1 << ISC01);
/* Enable INT0 */
EIMSK |= (1 << INT0);
sei();
}
int main(void) {
DDRB |= (1 << DDB5);
int0_init();
while (1) {
/* Main loop can do other things */
}
return 0;
}
Pin Change Interrupts
#include <avr/io.h>
#include <avr/interrupt.h>
/* PCINT0 interrupt (PB0-PB7) */
ISR(PCINT0_vect) {
/* Check which pin changed */
if (!(PINB & (1 << PINB0))) {
// PB0 is low
PORTB |= (1 << PORTB5);
} else {
PORTB &= ~(1 << PORTB5);
}
}
void pcint_init(void) {
/* Enable pull-up on PB0 */
PORTB |= (1 << PORTB0);
/* Enable PCINT0 (PB0) */
PCMSK0 |= (1 << PCINT0);
/* Enable pin change interrupt 0 */
PCICR |= (1 << PCIE0);
sei();
}
int main(void) {
DDRB |= (1 << DDB5);
DDRB &= ~(1 << DDB0);
pcint_init();
while (1) {
/* Main loop */
}
return 0;
}
Communication Protocols
UART (Serial Communication)
#include <avr/io.h>
#include <util/delay.h>
#define BAUD 9600
#define MYUBRR ((F_CPU/16/BAUD)-1)
void uart_init(void) {
/* Set baud rate */
UBRR0H = (MYUBRR >> 8);
UBRR0L = MYUBRR;
/* Enable transmitter and receiver */
UCSR0B = (1 << TXEN0) | (1 << RXEN0);
/* Set frame format: 8 data bits, 1 stop bit */
UCSR0C = (1 << UCSZ01) | (1 << UCSZ00);
}
void uart_transmit(uint8_t data) {
/* Wait for empty transmit buffer */
while (!(UCSR0A & (1 << UDRE0)));
/* Put data into buffer */
UDR0 = data;
}
uint8_t uart_receive(void) {
/* Wait for data */
while (!(UCSR0A & (1 << RXC0)));
/* Get and return data */
return UDR0;
}
void uart_print(const char* str) {
while (*str) {
uart_transmit(*str++);
}
}
int main(void) {
uart_init();
uart_print("Hello, AVR!\r\n");
while (1) {
uint8_t received = uart_receive();
uart_transmit(received); // Echo back
}
return 0;
}
SPI Master
#include <avr/io.h>
void spi_init(void) {
/* Set MOSI, SCK, and SS as outputs */
DDRB |= (1 << DDB3) | (1 << DDB5) | (1 << DDB2);
/* Set MISO as input */
DDRB &= ~(1 << DDB4);
/* Enable SPI, Master mode, clock = F_CPU/16 */
SPCR = (1 << SPE) | (1 << MSTR) | (1 << SPR0);
}
uint8_t spi_transfer(uint8_t data) {
/* Start transmission */
SPDR = data;
/* Wait for transmission complete */
while (!(SPSR & (1 << SPIF)));
/* Return received data */
return SPDR;
}
int main(void) {
spi_init();
while (1) {
/* Select device (SS low) */
PORTB &= ~(1 << PORTB2);
/* Send data */
spi_transfer(0xAB);
uint8_t received = spi_transfer(0x00);
/* Deselect device (SS high) */
PORTB |= (1 << PORTB2);
}
return 0;
}
I2C (TWI) Master
#include <avr/io.h>
#include <util/twi.h>
#define F_SCL 100000UL // 100 kHz
#define TWI_BITRATE (((F_CPU / F_SCL) - 16) / 2)
void i2c_init(void) {
/* Set bit rate */
TWBR = (uint8_t)TWI_BITRATE;
/* Enable TWI */
TWCR = (1 << TWEN);
}
void i2c_start(void) {
/* Send start condition */
TWCR = (1 << TWINT) | (1 << TWSTA) | (1 << TWEN);
/* Wait for completion */
while (!(TWCR & (1 << TWINT)));
}
void i2c_stop(void) {
/* Send stop condition */
TWCR = (1 << TWINT) | (1 << TWSTO) | (1 << TWEN);
}
void i2c_write(uint8_t data) {
/* Load data */
TWDR = data;
/* Start transmission */
TWCR = (1 << TWINT) | (1 << TWEN);
/* Wait for completion */
while (!(TWCR & (1 << TWINT)));
}
uint8_t i2c_read_ack(void) {
/* Enable ACK */
TWCR = (1 << TWINT) | (1 << TWEN) | (1 << TWEA);
/* Wait for completion */
while (!(TWCR & (1 << TWINT)));
return TWDR;
}
uint8_t i2c_read_nack(void) {
/* Enable NACK */
TWCR = (1 << TWINT) | (1 << TWEN);
/* Wait for completion */
while (!(TWCR & (1 << TWINT)));
return TWDR;
}
int main(void) {
i2c_init();
uint8_t device_addr = 0x68 << 1; // 7-bit address shifted left to make room for the R/W bit
uint8_t reg_addr = 0x00;
while (1) {
/* Write to device */
i2c_start();
i2c_write(device_addr | 0); // Write mode
i2c_write(reg_addr);
i2c_write(0x42); // Data
i2c_stop();
/* Read from device */
i2c_start();
i2c_write(device_addr | 0); // Write mode
i2c_write(reg_addr);
i2c_start(); // Repeated start
i2c_write(device_addr | 1); // Read mode
uint8_t data = i2c_read_nack();
i2c_stop();
}
return 0;
}
Advanced Topics
ADC (Analog-to-Digital Converter)
#include <avr/io.h>
#include <util/delay.h>
void adc_init(void) {
/* AVCC with external capacitor at AREF */
ADMUX = (1 << REFS0);
/* Enable ADC, prescaler 128 (125 kHz @ 16 MHz) */
ADCSRA = (1 << ADEN) | (1 << ADPS2) | (1 << ADPS1) | (1 << ADPS0);
}
uint16_t adc_read(uint8_t channel) {
/* Select channel (0-7) */
ADMUX = (ADMUX & 0xF0) | (channel & 0x0F);
/* Start conversion */
ADCSRA |= (1 << ADSC);
/* Wait for completion */
while (ADCSRA & (1 << ADSC));
return ADC;
}
int main(void) {
adc_init();
while (1) {
uint16_t value = adc_read(0); // Read ADC0
/* Convert to voltage (5V reference, 10-bit) */
float voltage = (value * 5.0) / 1024.0;
_delay_ms(100);
}
return 0;
}
EEPROM Access
#include <avr/io.h>
#include <avr/eeprom.h>
uint8_t EEMEM stored_value; // EEPROM variable
void eeprom_write_byte_custom(uint16_t address, uint8_t data) {
/* Wait for completion of previous write */
while (EECR & (1 << EEPE));
/* Set address and data registers */
EEAR = address;
EEDR = data;
/* Write logical one to EEMPE */
EECR |= (1 << EEMPE);
/* Start eeprom write by setting EEPE */
EECR |= (1 << EEPE);
}
uint8_t eeprom_read_byte_custom(uint16_t address) {
/* Wait for completion of previous write */
while (EECR & (1 << EEPE));
/* Set address register */
EEAR = address;
/* Start eeprom read by writing EERE */
EECR |= (1 << EERE);
/* Return data from data register */
return EEDR;
}
int main(void) {
/* Using avr-libc functions (recommended) */
eeprom_write_byte(&stored_value, 42);
uint8_t value = eeprom_read_byte(&stored_value);
/* Using custom functions */
eeprom_write_byte_custom(0, 100);
uint8_t val = eeprom_read_byte_custom(0);
while (1);
return 0;
}
Sleep Modes
#include <avr/io.h>
#include <avr/sleep.h>
#include <avr/interrupt.h>
ISR(INT0_vect) {
/* Wake up from sleep */
}
int main(void) {
/* Configure wake-up source (INT0 defaults to low-level trigger,
   the only INT0 mode that can wake the MCU from power-down) */
EIMSK |= (1 << INT0);
sei();
/* Set sleep mode */
set_sleep_mode(SLEEP_MODE_PWR_DOWN);
while (1) {
/* Enter sleep mode */
sleep_mode();
/* Wake up here and continue */
PORTB ^= (1 << PORTB5);
}
return 0;
}
Watchdog Timer
#include <avr/io.h>
#include <avr/wdt.h>
int main(void) {
/* Disable watchdog on reset */
MCUSR &= ~(1 << WDRF);
wdt_disable();
/* Enable watchdog: 2 second timeout */
wdt_enable(WDTO_2S);
while (1) {
/* Main program */
/* Reset watchdog timer */
wdt_reset();
}
return 0;
}
Best Practices
- Use Register Macros: PORTB |= (1 << PB5) instead of PORTB |= 0x20
- Volatile for ISR Variables: declare shared flags as volatile uint8_t flag;
- Minimize ISR Time: Keep interrupt handlers short
- Proper Delays: Use timers instead of _delay_ms() for long delays
- Power Management: Disable unused peripherals, use sleep modes
- Debouncing: Add delays or use interrupts with debounce logic
- Code Organization: Separate initialization from main loop
Troubleshooting
Common Issues
Program Not Running:
- Check fuse bits (clock source, brown-out detection)
- Verify F_CPU matches actual clock speed
- Ensure power supply is stable
Incorrect Baud Rate:
- Verify F_CPU is correct
- Check UBRR calculation
- Use standard baud rates
Fuse Bits:
# Read fuses
avrdude -c arduino -p atmega328p -U lfuse:r:-:h -U hfuse:r:-:h -U efuse:r:-:h
# Set fuses (CAREFUL!)
# Default for Arduino Uno: lfuse=0xFF, hfuse=0xDE, efuse=0xFD
avrdude -c arduino -p atmega328p -U lfuse:w:0xFF:m -U hfuse:w:0xDE:m -U efuse:w:0xFD:m
Resources
- AVR Libc Documentation: https://www.nongnu.org/avr-libc/
- Datasheets: https://www.microchip.com/
- AVR Tutorials: https://www.avrfreaks.net/
- Community: AVRFreaks forum
See Also
- Arduino Programming - Higher-level AVR programming
- GPIO Concepts
- UART Communication
- SPI Protocol
- I2C Protocol
- Timers and PWM
STM32 Microcontrollers
Comprehensive guide to STM32 development using HAL, CubeMX, and bare-metal programming.
Table of Contents
- Introduction
- STM32 Families
- Development Setup
- STM32CubeMX
- HAL Programming
- Bare Metal Programming
- Common Peripherals
- Advanced Topics
Introduction
STM32 is a family of 32-bit microcontrollers from STMicroelectronics based on ARM Cortex-M cores. They offer excellent performance, rich peripherals, and are widely used in professional and industrial applications.
Key Features
- ARM Cortex-M Cores: M0, M0+, M3, M4, M7, M33
- Clock Speed: 48 MHz to 550 MHz
- Memory: 16 KB to 2 MB Flash, 4 KB to 1 MB RAM
- Peripherals: GPIO, UART, SPI, I2C, ADC, DAC, Timers, USB, CAN, Ethernet
- Development Tools: Free official IDE and HAL libraries
- Price: $1 to $20 depending on series
Advantages
- Professional-grade reliability
- Extensive peripheral set
- Low power consumption
- Strong ecosystem and support
- Pin-compatible families
- Real-time performance
STM32 Families
Overview
| Family | Core | Speed | Flash | Use Case | Examples |
|---|---|---|---|---|---|
| F0 | M0 | 48 MHz | 16-256 KB | Entry-level, cost-sensitive | STM32F030 |
| F1 | M3 | 72 MHz | 16-512 KB | General purpose, classic | STM32F103 (Blue Pill) |
| F4 | M4 | 180 MHz | 256 KB-2 MB | High performance, DSP, FPU | STM32F407, F429 |
| F7 | M7 | 216 MHz | 512 KB-2 MB | Very high performance | STM32F746 |
| H7 | M7 | 480 MHz | 1-2 MB | Extreme performance | STM32H743 |
| L0/L4 | M0+/M4 | 32-80 MHz | 16-512 KB | Ultra-low power | STM32L476 |
| G0/G4 | M0+/M4 | 64-170 MHz | 32-512 KB | Mainstream, motor control | STM32G474 |
Popular Development Boards
STM32 Nucleo Boards
┌────────────────────────────────┐
│ STM32 Nucleo-64 │
│ │
│ ┌─────────────────┐ │
│ │ STM32 MCU │ │
│ │ (QFP64) │ │
│ └─────────────────┘ │
│ │
│ [CN7] ═══════════════ [CN10] │ Arduino Headers
│ [CN8] ═══════════════ [CN9] │
│ │
│ [CN1] ST-LINK V2-1 │
│ [USB] │
└────────────────────────────────┘
Features:
- Integrated ST-LINK debugger/programmer
- Arduino Uno R3 compatible headers
- Morpho extension headers (full pin access)
- Virtual COM port
- Price: ~$15
Blue Pill (STM32F103C8T6)
┌──────────────────────────┐
│ STM32F103C8T6 │
│ "Blue Pill" │
│ │
│ [USB] ═══════════ [SWD] │
│ │
│ ╔════════════════════╗ │
│ ║ Header Pins ║ │
│ ║ (40 pins total) ║ │
│ ╚════════════════════╝ │
│ │
│ [3.3V] [5V] [GND] │
└──────────────────────────┘
Specs:
- 72 MHz ARM Cortex-M3
- 64 KB Flash, 20 KB RAM
- 37 GPIO pins
- 2x SPI, 2x I2C, 3x USART
- 12-bit ADC, 2x DAC
- Price: ~$2
Development Setup
STM32CubeIDE (Recommended)
# Download from ST website:
# https://www.st.com/en/development-tools/stm32cubeide.html
# Linux installation:
sudo chmod +x st-stm32cubeide_*.sh
sudo ./st-stm32cubeide_*.sh
# Install udev rules for ST-LINK
sudo cp ~/STMicroelectronics/STM32Cube/STM32CubeIDE/Drivers/rules/*.* /etc/udev/rules.d/
sudo udevadm control --reload-rules
Alternative: Command Line Setup
# Install ARM toolchain
sudo apt install gcc-arm-none-eabi gdb-multiarch
# Install OpenOCD (programming/debugging)
sudo apt install openocd
# Install st-link utilities
sudo apt install stlink-tools
# Verify installation
arm-none-eabi-gcc --version
openocd --version
st-info --version
PlatformIO Setup
pip install platformio
# Create project
pio init --board nucleo_f401re
# platformio.ini
[env:nucleo_f401re]
platform = ststm32
board = nucleo_f401re
framework = arduino
# or framework = stm32cube
STM32CubeMX
STM32CubeMX is a graphical configuration tool that generates initialization code for STM32 microcontrollers.
Creating a Project
1. Start New Project
   - File > New Project
   - Select your MCU or board
   - Click "Start Project"
2. Configure Clock
   - Clock Configuration tab
   - Set HSE/HSI source
   - Configure PLL multipliers
   - Set system clock (HCLK)
3. Configure Peripherals
   - Pinout & Configuration tab
   - Click on pins to assign functions
   - Configure peripheral parameters
4. Generate Code
   - Project Manager tab
   - Set project name and location
   - Select toolchain (STM32CubeIDE, Makefile, etc.)
   - Click "Generate Code"
Example: Blink LED Configuration
1. Pinout Configuration:
- Find LED pin (e.g., PC13 on Blue Pill)
- Set as GPIO_Output
- Label it "LED"
2. GPIO Configuration:
- Mode: Output Push Pull
- Pull-up/Pull-down: No pull-up and no pull-down
- Maximum output speed: Low
- User Label: LED
3. Clock Configuration:
- HSE: 8 MHz (external crystal)
- PLL: ×9 (72 MHz system clock)
4. Generate Code
Project Structure
project/
├── Core/
│ ├── Inc/
│ │ ├── main.h
│ │ ├── stm32f1xx_it.h
│ │ └── stm32f1xx_hal_conf.h
│ └── Src/
│ ├── main.c
│ ├── stm32f1xx_it.c
│ └── system_stm32f1xx.c
├── Drivers/
│ ├── STM32F1xx_HAL_Driver/
│ └── CMSIS/
└── Makefile
HAL Programming
Basic HAL Blink
/* main.c - Generated by CubeMX */
#include "main.h"
GPIO_InitTypeDef GPIO_InitStruct = {0};
void SystemClock_Config(void);
static void MX_GPIO_Init(void);
int main(void) {
/* Initialize HAL Library */
HAL_Init();
/* Configure system clock */
SystemClock_Config();
/* Initialize GPIO */
MX_GPIO_Init();
/* Infinite loop */
while (1) {
HAL_GPIO_TogglePin(GPIOC, GPIO_PIN_13);
HAL_Delay(1000);
}
}
static void MX_GPIO_Init(void) {
/* Enable GPIO Clock */
__HAL_RCC_GPIOC_CLK_ENABLE();
/* Configure GPIO pin */
GPIO_InitStruct.Pin = GPIO_PIN_13;
GPIO_InitStruct.Mode = GPIO_MODE_OUTPUT_PP;
GPIO_InitStruct.Pull = GPIO_NOPULL;
GPIO_InitStruct.Speed = GPIO_SPEED_FREQ_LOW;
HAL_GPIO_Init(GPIOC, &GPIO_InitStruct);
}
void SystemClock_Config(void) {
/* Generated by CubeMX - configures clocks */
}
GPIO Functions
/* Write pin */
HAL_GPIO_WritePin(GPIOC, GPIO_PIN_13, GPIO_PIN_SET); // High
HAL_GPIO_WritePin(GPIOC, GPIO_PIN_13, GPIO_PIN_RESET); // Low
/* Toggle pin */
HAL_GPIO_TogglePin(GPIOC, GPIO_PIN_13);
/* Read pin */
GPIO_PinState state = HAL_GPIO_ReadPin(GPIOA, GPIO_PIN_0);
/* External interrupt */
HAL_GPIO_EXTI_IRQHandler(GPIO_PIN_0); // Call in ISR
void HAL_GPIO_EXTI_Callback(uint16_t GPIO_Pin); // Override this
Button with Interrupt
/* Configure button with external interrupt in CubeMX:
PA0 -> GPIO_EXTI0
Mode: External Interrupt Mode with Rising edge trigger detection
Pull-up: Pull-up
In NVIC tab: Enable EXTI line0 interrupt
*/
/* main.c */
volatile uint8_t button_pressed = 0;
int main(void) {
HAL_Init();
SystemClock_Config();
MX_GPIO_Init();
while (1) {
if (button_pressed) {
button_pressed = 0;
HAL_GPIO_TogglePin(GPIOC, GPIO_PIN_13);
}
}
}
/* Interrupt callback - implement this */
void HAL_GPIO_EXTI_Callback(uint16_t GPIO_Pin) {
if (GPIO_Pin == GPIO_PIN_0) {
button_pressed = 1;
}
}
/* stm32f1xx_it.c - Generated by CubeMX */
void EXTI0_IRQHandler(void) {
HAL_GPIO_EXTI_IRQHandler(GPIO_PIN_0);
}
UART Communication
/* Configure UART in CubeMX:
USART1: PA9 (TX), PA10 (RX)
Baud Rate: 115200
Word Length: 8 Bits
Stop Bits: 1
Parity: None
*/
UART_HandleTypeDef huart1;
int main(void) {
HAL_Init();
SystemClock_Config();
MX_USART1_UART_Init();
uint8_t msg[] = "Hello, STM32!\r\n";
HAL_UART_Transmit(&huart1, msg, sizeof(msg)-1, HAL_MAX_DELAY);
uint8_t rx_buffer[10];
while (1) {
/* Blocking receive */
HAL_UART_Receive(&huart1, rx_buffer, 1, HAL_MAX_DELAY);
/* Echo back */
HAL_UART_Transmit(&huart1, rx_buffer, 1, HAL_MAX_DELAY);
}
}
/* printf redirect */
int _write(int file, char *ptr, int len) {
HAL_UART_Transmit(&huart1, (uint8_t*)ptr, len, HAL_MAX_DELAY);
return len;
}
ADC Reading
/* Configure ADC in CubeMX:
ADC1, Channel 0 (PA0)
Resolution: 12 bits
Continuous Conversion: Disabled
*/
ADC_HandleTypeDef hadc1;
int main(void) {
HAL_Init();
SystemClock_Config();
MX_ADC1_Init();
while (1) {
HAL_ADC_Start(&hadc1);
HAL_ADC_PollForConversion(&hadc1, HAL_MAX_DELAY);
uint32_t adc_value = HAL_ADC_GetValue(&hadc1);
/* Convert to voltage (3.3V reference, 12-bit) */
float voltage = (adc_value * 3.3f) / 4096.0f;
printf("ADC: %lu, Voltage: %.2f V\r\n", adc_value, voltage);
HAL_Delay(1000);
}
}
PWM Output
/* Configure Timer in CubeMX:
TIM2, Channel 1 (PA0)
Mode: PWM Generation CH1
Prescaler: 72-1 (1 MHz timer clock)
Counter Period: 1000-1 (1 kHz PWM)
*/
TIM_HandleTypeDef htim2;
int main(void) {
HAL_Init();
SystemClock_Config();
MX_TIM2_Init();
/* Start PWM */
HAL_TIM_PWM_Start(&htim2, TIM_CHANNEL_1);
while (1) {
/* Fade in */
for (uint16_t duty = 0; duty <= 1000; duty += 10) {
__HAL_TIM_SET_COMPARE(&htim2, TIM_CHANNEL_1, duty);
HAL_Delay(10);
}
/* Fade out */
for (uint16_t duty = 1000; duty > 0; duty -= 10) {
__HAL_TIM_SET_COMPARE(&htim2, TIM_CHANNEL_1, duty);
HAL_Delay(10);
}
}
}
I2C Communication
/* Configure I2C in CubeMX:
I2C1: PB6 (SCL), PB7 (SDA)
Speed: 100 kHz (Standard Mode)
*/
I2C_HandleTypeDef hi2c1;
#define DEVICE_ADDR (0x68 << 1) // 7-bit address shifted left for the R/W bit
int main(void) {
HAL_Init();
SystemClock_Config();
MX_I2C1_Init();
uint8_t tx_data = 0x00;
uint8_t rx_data[2];
while (1) {
/* Write register address */
HAL_I2C_Master_Transmit(&hi2c1, DEVICE_ADDR, &tx_data, 1, HAL_MAX_DELAY);
/* Read data */
HAL_I2C_Master_Receive(&hi2c1, DEVICE_ADDR, rx_data, 2, HAL_MAX_DELAY);
HAL_Delay(1000);
}
}
SPI Communication
/* Configure SPI in CubeMX:
SPI1: PA5 (SCK), PA6 (MISO), PA7 (MOSI)
Mode: Master
Baud Rate Prescaler: 32
Data Size: 8 Bits
*/
SPI_HandleTypeDef hspi1;
#define CS_PIN GPIO_PIN_4
#define CS_PORT GPIOA
int main(void) {
HAL_Init();
SystemClock_Config();
MX_SPI1_Init();
MX_GPIO_Init(); // CS pin
uint8_t tx_data[] = {0x01, 0x02, 0x03};
uint8_t rx_data[3];
while (1) {
/* Select device */
HAL_GPIO_WritePin(CS_PORT, CS_PIN, GPIO_PIN_RESET);
/* Transfer data */
HAL_SPI_TransmitReceive(&hspi1, tx_data, rx_data, 3, HAL_MAX_DELAY);
/* Deselect device */
HAL_GPIO_WritePin(CS_PORT, CS_PIN, GPIO_PIN_SET);
HAL_Delay(1000);
}
}
Bare Metal Programming
Direct Register Access
/* Blink LED without HAL - STM32F103 */
#include "stm32f1xx.h"
int main(void) {
/* Enable GPIOC clock */
RCC->APB2ENR |= RCC_APB2ENR_IOPCEN;
/* Configure PC13 as output push-pull, max speed 2 MHz */
GPIOC->CRH &= ~(GPIO_CRH_MODE13 | GPIO_CRH_CNF13);
GPIOC->CRH |= GPIO_CRH_MODE13_1; // Output mode, 2 MHz
while (1) {
/* Toggle LED */
GPIOC->ODR ^= GPIO_ODR_ODR13;
/* Delay */
for (volatile uint32_t i = 0; i < 1000000; i++);
}
}
GPIO Register Operations
/* Set pin high */
GPIOC->BSRR = GPIO_BSRR_BS13; // Bit Set
/* Set pin low */
GPIOC->BSRR = GPIO_BSRR_BR13; // Bit Reset
/* Toggle pin */
GPIOC->ODR ^= GPIO_ODR_ODR13;
/* Read pin */
uint32_t state = GPIOA->IDR & GPIO_IDR_IDR0;
UART Bare Metal
/* Initialize UART1 - 115200 baud, 72 MHz clock */
void UART1_Init(void) {
/* Enable clocks */
RCC->APB2ENR |= RCC_APB2ENR_USART1EN | RCC_APB2ENR_IOPAEN;
/* Configure PA9 (TX) as alternate function push-pull */
GPIOA->CRH &= ~(GPIO_CRH_MODE9 | GPIO_CRH_CNF9);
GPIOA->CRH |= GPIO_CRH_MODE9_1 | GPIO_CRH_CNF9_1;
/* Configure PA10 (RX) as input floating */
GPIOA->CRH &= ~(GPIO_CRH_MODE10 | GPIO_CRH_CNF10);
GPIOA->CRH |= GPIO_CRH_CNF10_0;
/* Configure UART */
USART1->BRR = 0x271; // 115200 baud at 72 MHz
USART1->CR1 = USART_CR1_TE | USART_CR1_RE | USART_CR1_UE;
}
void UART1_SendChar(char c) {
while (!(USART1->SR & USART_SR_TXE));
USART1->DR = c;
}
char UART1_ReceiveChar(void) {
while (!(USART1->SR & USART_SR_RXNE));
return USART1->DR;
}
Timer Interrupt
/* Configure TIM2 for 1 second interrupt */
void TIM2_Init(void) {
/* Enable TIM2 clock */
RCC->APB1ENR |= RCC_APB1ENR_TIM2EN;
/* Configure timer:
72 MHz / 7200 = 10 kHz
10 kHz / 10000 = 1 Hz (1 second)
*/
TIM2->PSC = 7200 - 1; // Prescaler
TIM2->ARR = 10000 - 1; // Auto-reload
TIM2->DIER |= TIM_DIER_UIE; // Update interrupt enable
TIM2->CR1 |= TIM_CR1_CEN; // Enable timer
/* Enable interrupt in NVIC */
NVIC_EnableIRQ(TIM2_IRQn);
}
/* Interrupt handler */
void TIM2_IRQHandler(void) {
if (TIM2->SR & TIM_SR_UIF) {
TIM2->SR &= ~TIM_SR_UIF; // Clear interrupt flag
/* Toggle LED */
GPIOC->ODR ^= GPIO_ODR_ODR13;
}
}
Common Peripherals
DMA Transfer
/* Configure DMA for UART TX in CubeMX:
DMA1, Channel 4
Direction: Memory to Peripheral
Mode: Normal
*/
uint8_t tx_buffer[] = "Hello from DMA!\r\n";
int main(void) {
HAL_Init();
SystemClock_Config();
MX_USART1_UART_Init();
MX_DMA_Init();
while (1) {
HAL_UART_Transmit_DMA(&huart1, tx_buffer, sizeof(tx_buffer)-1);
HAL_Delay(1000);
}
}
/* DMA transfer complete callback */
void HAL_UART_TxCpltCallback(UART_HandleTypeDef *huart) {
/* Transfer complete - can start next */
}
RTC (Real-Time Clock)
/* Configure RTC in CubeMX:
RTC Activated
Clock Source: LSE (32.768 kHz)
*/
RTC_TimeTypeDef sTime;
RTC_DateTypeDef sDate;
int main(void) {
HAL_Init();
SystemClock_Config();
MX_RTC_Init();
/* Set time */
sTime.Hours = 12;
sTime.Minutes = 0;
sTime.Seconds = 0;
HAL_RTC_SetTime(&hrtc, &sTime, RTC_FORMAT_BIN);
/* Set date */
sDate.Year = 24;
sDate.Month = 1;
sDate.Date = 15;
HAL_RTC_SetDate(&hrtc, &sDate, RTC_FORMAT_BIN);
while (1) {
HAL_RTC_GetTime(&hrtc, &sTime, RTC_FORMAT_BIN);
HAL_RTC_GetDate(&hrtc, &sDate, RTC_FORMAT_BIN);
printf("%02d:%02d:%02d\r\n",
sTime.Hours, sTime.Minutes, sTime.Seconds);
HAL_Delay(1000);
}
}
Watchdog Timer
/* Configure IWDG in CubeMX:
Independent Watchdog
Prescaler: 32
Counter Reload Value: 4095 (max ~4 seconds)
*/
IWDG_HandleTypeDef hiwdg;
int main(void) {
HAL_Init();
SystemClock_Config();
MX_IWDG_Init();
while (1) {
/* Main program tasks */
/* Refresh watchdog */
HAL_IWDG_Refresh(&hiwdg);
HAL_Delay(100);
}
}
Advanced Topics
FreeRTOS Integration
/* Enable FreeRTOS in CubeMX */
#include "FreeRTOS.h"
#include "task.h"
void Task1(void *argument);
void Task2(void *argument);
int main(void) {
HAL_Init();
SystemClock_Config();
MX_GPIO_Init();
/* Create tasks */
xTaskCreate(Task1, "Task1", 128, NULL, 1, NULL);
xTaskCreate(Task2, "Task2", 128, NULL, 1, NULL);
/* Start scheduler */
vTaskStartScheduler();
/* Never reached */
while (1);
}
void Task1(void *argument) {
while (1) {
HAL_GPIO_TogglePin(GPIOC, GPIO_PIN_13);
vTaskDelay(pdMS_TO_TICKS(500));
}
}
void Task2(void *argument) {
while (1) {
/* Other task */
vTaskDelay(pdMS_TO_TICKS(1000));
}
}
Low Power Modes
/* Enter Stop mode */
HAL_PWR_EnterSTOPMode(PWR_LOWPOWERREGULATOR_ON, PWR_STOPENTRY_WFI);
/* Enter Standby mode */
HAL_PWR_EnterSTANDBYMode();
/* Enter Sleep mode */
HAL_PWR_EnterSLEEPMode(PWR_MAINREGULATOR_ON, PWR_SLEEPENTRY_WFI);
Bootloader
/* Jump to bootloader (system memory) */
void JumpToBootloader(void) {
void (*SysMemBootJump)(void);
/* Set bootloader address (STM32F1: 0x1FFFF000) */
volatile uint32_t addr = 0x1FFFF000;
/* Disable interrupts */
__disable_irq();
/* Remap system memory to 0x00000000 */
__HAL_RCC_SYSCFG_CLK_ENABLE();
__HAL_SYSCFG_REMAPMEMORY_SYSTEMFLASH();
/* Set jump address */
SysMemBootJump = (void (*)(void)) (*((uint32_t *)(addr + 4)));
/* Set main stack pointer */
__set_MSP(*(uint32_t *)addr);
/* Jump */
SysMemBootJump();
while (1);
}
Best Practices
- Use CubeMX: Generate initialization code automatically
- HAL vs LL: HAL for ease, LL for performance
- Interrupts: Keep ISRs short, use callbacks
- DMA: Use for high-speed data transfers
- Power: Disable unused peripherals
- Debugging: Use SWD with ST-LINK
- Version Control: Track CubeMX .ioc files
Troubleshooting
Common Issues
Debugger Not Connecting:
# Check ST-LINK connection
st-info --probe
# Reset ST-LINK
st-flash reset
# Update ST-LINK firmware
# Use STM32 ST-LINK Utility
Clock Configuration:
- Verify HSE frequency matches hardware
- Check PLL multipliers for target frequency
- Enable required peripheral clocks
GPIO Not Working:
- Enable GPIO clock first
- Check pin alternate functions
- Verify pin configuration (mode, speed, pull)
Printf Not Working:
// Enable semi-hosting or retarget _write()
int _write(int file, char *ptr, int len) {
HAL_UART_Transmit(&huart1, (uint8_t*)ptr, len, HAL_MAX_DELAY);
return len;
}
Resources
- STM32 Website: https://www.st.com/stm32
- CubeMX: https://www.st.com/en/development-tools/stm32cubemx.html
- HAL Documentation: STM32 HAL user manual per family
- Reference Manuals: Detailed peripheral descriptions
- Community: https://community.st.com/
See Also
ESP32
Comprehensive guide to ESP32 microcontroller development with WiFi and Bluetooth capabilities.
Table of Contents
- Introduction
- Hardware Overview
- Development Setup
- Basic Programming
- WiFi Connectivity
- Bluetooth
- Advanced Features
- Projects
Introduction
The ESP32 is a powerful, low-cost microcontroller with integrated WiFi and Bluetooth. Developed by Espressif Systems, it's ideal for IoT projects and wireless applications.
Key Features
- Dual-core Xtensa LX6 (or single-core RISC-V in ESP32-C3)
- Clock Speed: 160-240 MHz
- Memory: 520 KB SRAM, 4 MB Flash (typical)
- WiFi: 802.11 b/g/n (2.4 GHz)
- Bluetooth: BLE 4.2 and Classic Bluetooth
- GPIO: Up to 34 programmable pins
- Peripherals: ADC, DAC, SPI, I2C, UART, PWM, I2S
- Low Power: Multiple sleep modes
- Price: $2-$10 depending on variant
ESP32 Variants
| Variant | Cores | WiFi | BLE | Classic BT | USB | Special Features |
|---|---|---|---|---|---|---|
| ESP32 | 2 | Yes | Yes | Yes | No | Original, most common |
| ESP32-S2 | 1 | Yes | No | No | Native | USB OTG, lower power |
| ESP32-S3 | 2 | Yes | Yes | No | Native | Vector instructions |
| ESP32-C3 | 1 (RISC-V) | Yes | Yes | No | Native | RISC-V architecture |
| ESP32-C6 | 1 (RISC-V) | Yes | Yes | No | Native | WiFi 6, Zigbee |
Hardware Overview
ESP32 DevKit Pinout
ESP32 DevKit
┌─────────────────┐
│ USB │
├─────────────────┤
3V3 [ ]──┤3V3 D23├──[ ] MOSI
EN [ ]──┤EN D22├──[ ] SCL (I2C)
VP/36 [ ]──┤VP/A0 TX0├──[ ] TX
VN/39 [ ]──┤VN/A3 RX0├──[ ] RX
D34 [ ]──┤34/A6 D21├──[ ] SDA (I2C)
D35 [ ]──┤35/A7 GND├──[ ] GND
D32 [ ]──┤32/A4 D19├──[ ] MISO
D33 [ ]──┤33/A5 D18├──[ ] SCK
D25 [ ]──┤25/A18 D5 ├──[ ] SS
D26 [ ]──┤26/A19 D17├──[ ] TX2
D27 [ ]──┤27/A17 D16├──[ ] RX2
D14 [ ]──┤14/A16 D4 ├──[ ]
D12 [ ]──┤12/A15 D0 ├──[ ] (Boot)
GND [ ]──┤GND D2 ├──[ ] (LED)
D13 [ ]──┤13/A14 D15├──[ ]
D9 [ ]──┤9/SD2 D8├──[ ] SD1
D10 [ ]──┤10/SD3 D7├──[ ] SD0
D11 [ ]──┤11/CMD D6├──[ ] SCK
VIN [ ]──┤VIN 5V├──[ ] 5V
└─────────────────┘
Note: Pins 6-11 connected to flash (avoid using)
Pins with boot/strapping modes: 0, 2, 5, 12, 15
Important Notes
- Input Only Pins: GPIO 34-39 (no pull-up/pull-down)
- Strapping Pins: 0, 2, 5, 12, 15 (affect boot mode)
- Boot Mode: GPIO 0 LOW = download mode
- Built-in LED: Usually GPIO 2
- ADC2: Cannot use while WiFi active (GPIO 0, 2, 4, 12-15, 25-27)
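The pin constraints above can be captured in a small host-side helper. This is a plain C++ sketch for planning wiring, not part of any ESP32 API; the function name and categories are illustrative only:

```cpp
#include <cassert>
#include <string>

// Classify an ESP32 GPIO number per the notes above:
// 34-39 input-only, 0/2/5/12/15 strapping, 6-11 wired to SPI flash.
std::string classifyPin(int gpio) {
    if (gpio >= 34 && gpio <= 39) return "input-only";
    switch (gpio) {
        case 0: case 2: case 5: case 12: case 15:
            return "strapping";   // affects boot mode, use with care
        case 6: case 7: case 8: case 9: case 10: case 11:
            return "flash";       // connected to flash chip, avoid
        default:
            return "ok";
    }
}
```

Running a pin number through such a check before committing to a wiring layout avoids the most common "why won't it boot" mistakes.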
Development Setup
Arduino IDE Setup
# 1. Install Arduino IDE from arduino.cc
# 2. Add ESP32 Board Manager URL:
# File > Preferences > Additional Board Manager URLs
# Add: https://raw.githubusercontent.com/espressif/arduino-esp32/gh-pages/package_esp32_index.json
# 3. Install ESP32 boards:
# Tools > Board > Boards Manager > Search "ESP32" > Install
# 4. Select your board:
# Tools > Board > ESP32 Arduino > ESP32 Dev Module
ESP-IDF Setup (Official Framework)
# Clone ESP-IDF
git clone --recursive https://github.com/espressif/esp-idf.git
cd esp-idf
# Install (Linux/Mac)
./install.sh
# Set up environment (run in each terminal session)
. ./export.sh
# Or add to ~/.bashrc:
alias get_idf='. $HOME/esp/esp-idf/export.sh'
# Create new project
idf.py create-project myproject
cd myproject
# Configure
idf.py menuconfig
# Build
idf.py build
# Flash
idf.py -p /dev/ttyUSB0 flash
# Monitor serial output
idf.py -p /dev/ttyUSB0 monitor
PlatformIO Setup
# Install PlatformIO
pip install platformio
# Create ESP32 project
pio init --board esp32dev
# Build and upload
pio run --target upload
# Serial monitor
pio device monitor
Basic Programming
Blink LED (Arduino Framework)
#define LED_PIN 2
void setup() {
pinMode(LED_PIN, OUTPUT);
}
void loop() {
digitalWrite(LED_PIN, HIGH);
delay(1000);
digitalWrite(LED_PIN, LOW);
delay(1000);
}
Dual Core Programming
TaskHandle_t Task1;
TaskHandle_t Task2;
void setup() {
Serial.begin(115200);
// Create task for core 0
xTaskCreatePinnedToCore(
Task1code, // Function
"Task1", // Name
10000, // Stack size
NULL, // Parameters
1, // Priority
&Task1, // Task handle
0 // Core ID
);
// Create task for core 1
xTaskCreatePinnedToCore(
Task2code,
"Task2",
10000,
NULL,
1,
&Task2,
1
);
}
void Task1code(void * parameter) {
while(1) {
Serial.print("Task 1 running on core ");
Serial.println(xPortGetCoreID());
delay(1000);
}
}
void Task2code(void * parameter) {
while(1) {
Serial.print("Task 2 running on core ");
Serial.println(xPortGetCoreID());
delay(500);
}
}
void loop() {
// Empty - tasks handle everything
}
Touch Sensor
const int TOUCH_PIN = 4; // T0
const int THRESHOLD = 40;
void setup() {
Serial.begin(115200);
}
void loop() {
int touchValue = touchRead(TOUCH_PIN);
Serial.println(touchValue);
if (touchValue < THRESHOLD) {
Serial.println("Touch detected!");
}
delay(500);
}
Hall Effect Sensor (Built-in)
void setup() {
Serial.begin(115200);
}
void loop() {
// Read built-in Hall effect sensor (original ESP32 only; hallRead() was removed in arduino-esp32 core 3.x)
int hallValue = hallRead();
Serial.print("Hall Sensor: ");
Serial.println(hallValue);
delay(500);
}
WiFi Connectivity
WiFi Station Mode (Connect to Router)
#include <WiFi.h>
const char* ssid = "YourSSID";
const char* password = "YourPassword";
void setup() {
Serial.begin(115200);
// Connect to WiFi
WiFi.begin(ssid, password);
Serial.print("Connecting to WiFi");
while (WiFi.status() != WL_CONNECTED) {
delay(500);
Serial.print(".");
}
Serial.println("\nConnected!");
Serial.print("IP Address: ");
Serial.println(WiFi.localIP());
Serial.print("MAC Address: ");
Serial.println(WiFi.macAddress());
Serial.print("Signal Strength (RSSI): ");
Serial.print(WiFi.RSSI());
Serial.println(" dBm");
}
void loop() {
// Check connection
if (WiFi.status() != WL_CONNECTED) {
Serial.println("WiFi disconnected!");
WiFi.reconnect();
}
delay(10000);
}
WiFi Access Point Mode
#include <WiFi.h>
const char* ssid = "ESP32-AP";
const char* password = "12345678"; // Minimum 8 characters
void setup() {
Serial.begin(115200);
// Start Access Point
WiFi.softAP(ssid, password);
IPAddress IP = WiFi.softAPIP();
Serial.print("AP IP address: ");
Serial.println(IP);
}
void loop() {
// Print number of connected stations
Serial.print("Stations connected: ");
Serial.println(WiFi.softAPgetStationNum());
delay(5000);
}
Web Server
#include <WiFi.h>
#include <WebServer.h>
const char* ssid = "YourSSID";
const char* password = "YourPassword";
WebServer server(80);
const int LED_PIN = 2;
bool ledState = false;
void handleRoot() {
String html = "<html><body>";
html += "<h1>ESP32 Web Server</h1>";
html += "<p>LED is: " + String(ledState ? "ON" : "OFF") + "</p>";
html += "<p><a href=\"/led/on\"><button>Turn ON</button></a></p>";
html += "<p><a href=\"/led/off\"><button>Turn OFF</button></a></p>";
html += "</body></html>";
server.send(200, "text/html", html);
}
void handleLEDOn() {
ledState = true;
digitalWrite(LED_PIN, HIGH);
server.sendHeader("Location", "/");
server.send(303);
}
void handleLEDOff() {
ledState = false;
digitalWrite(LED_PIN, LOW);
server.sendHeader("Location", "/");
server.send(303);
}
void setup() {
Serial.begin(115200);
pinMode(LED_PIN, OUTPUT);
// Connect to WiFi
WiFi.begin(ssid, password);
while (WiFi.status() != WL_CONNECTED) {
delay(500);
Serial.print(".");
}
Serial.println("\nConnected!");
Serial.print("IP: ");
Serial.println(WiFi.localIP());
// Setup server routes
server.on("/", handleRoot);
server.on("/led/on", handleLEDOn);
server.on("/led/off", handleLEDOff);
server.begin();
Serial.println("Web server started");
}
void loop() {
server.handleClient();
}
HTTP Client (GET Request)
#include <WiFi.h>
#include <HTTPClient.h>
const char* ssid = "YourSSID";
const char* password = "YourPassword";
void setup() {
Serial.begin(115200);
WiFi.begin(ssid, password);
while (WiFi.status() != WL_CONNECTED) {
delay(500);
Serial.print(".");
}
Serial.println("\nConnected!");
}
void loop() {
if (WiFi.status() == WL_CONNECTED) {
HTTPClient http;
http.begin("http://api.github.com/users/octocat"); // GitHub redirects to HTTPS; use WiFiClientSecure for https:// URLs
int httpCode = http.GET();
if (httpCode > 0) {
Serial.printf("HTTP Code: %d\n", httpCode);
if (httpCode == HTTP_CODE_OK) {
String payload = http.getString();
Serial.println(payload);
}
} else {
Serial.printf("Error: %s\n", http.errorToString(httpCode).c_str());
}
http.end();
}
delay(10000);
}
WiFi Scan
#include <WiFi.h>
void setup() {
Serial.begin(115200);
WiFi.mode(WIFI_STA);
WiFi.disconnect();
delay(100);
}
void loop() {
Serial.println("Scanning WiFi networks...");
int n = WiFi.scanNetworks();
if (n == 0) {
Serial.println("No networks found");
} else {
Serial.printf("%d networks found:\n", n);
for (int i = 0; i < n; i++) {
Serial.printf("%d: %s (%d dBm) %s\n",
i + 1,
WiFi.SSID(i).c_str(),
WiFi.RSSI(i),
WiFi.encryptionType(i) == WIFI_AUTH_OPEN ? "Open" : "Encrypted"
);
}
}
delay(5000);
}
MQTT Client
#include <WiFi.h>
#include <PubSubClient.h>
const char* ssid = "YourSSID";
const char* password = "YourPassword";
const char* mqtt_server = "broker.hivemq.com";
WiFiClient espClient;
PubSubClient client(espClient);
void callback(char* topic, byte* payload, unsigned int length) {
Serial.print("Message arrived [");
Serial.print(topic);
Serial.print("]: ");
for (unsigned int i = 0; i < length; i++) {
Serial.print((char)payload[i]);
}
Serial.println();
}
void reconnect() {
while (!client.connected()) {
Serial.print("Attempting MQTT connection...");
if (client.connect("ESP32Client")) {
Serial.println("connected");
client.subscribe("esp32/test");
} else {
Serial.print("failed, rc=");
Serial.print(client.state());
Serial.println(" retrying in 5 seconds");
delay(5000);
}
}
}
void setup() {
Serial.begin(115200);
WiFi.begin(ssid, password);
while (WiFi.status() != WL_CONNECTED) {
delay(500);
Serial.print(".");
}
Serial.println("\nWiFi connected");
client.setServer(mqtt_server, 1883);
client.setCallback(callback);
}
void loop() {
if (!client.connected()) {
reconnect();
}
client.loop();
// Publish message every 10 seconds
static unsigned long lastMsg = 0;
unsigned long now = millis();
if (now - lastMsg > 10000) {
lastMsg = now;
char msg[50];
snprintf(msg, 50, "Hello from ESP32 #%lu", millis());
client.publish("esp32/test", msg);
}
}
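The non-blocking timing pattern above (`now - lastMsg > 10000`) is rollover-safe: `millis()` wraps after roughly 49.7 days, but unsigned subtraction still yields the correct elapsed time modulo 2^32. A host-side sketch of the arithmetic:

```cpp
#include <cassert>
#include <cstdint>

// millis() on Arduino returns an unsigned 32-bit value that eventually wraps.
// Unsigned subtraction is well-defined modulo 2^32, so elapsed time stays
// correct even when 'now' has wrapped past zero while 'last' has not.
uint32_t elapsed(uint32_t now, uint32_t last) {
    return now - last;
}
```

This is why the idiom is always `now - last > interval` rather than comparing absolute timestamps.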
Bluetooth
Bluetooth Classic - Serial
#include <BluetoothSerial.h>
BluetoothSerial SerialBT;
void setup() {
Serial.begin(115200);
SerialBT.begin("ESP32-BT"); // Bluetooth device name
Serial.println("Bluetooth Started! Pair with 'ESP32-BT'");
}
void loop() {
// Forward from Serial to Bluetooth
if (Serial.available()) {
SerialBT.write(Serial.read());
}
// Forward from Bluetooth to Serial
if (SerialBT.available()) {
Serial.write(SerialBT.read());
}
}
BLE Server
#include <BLEDevice.h>
#include <BLEServer.h>
#include <BLEUtils.h>
#include <BLE2902.h>
BLEServer* pServer = NULL;
BLECharacteristic* pCharacteristic = NULL;
bool deviceConnected = false;
uint32_t value = 0;
#define SERVICE_UUID "4fafc201-1fb5-459e-8fcc-c5c9c331914b"
#define CHARACTERISTIC_UUID "beb5483e-36e1-4688-b7f5-ea07361b26a8"
class MyServerCallbacks: public BLEServerCallbacks {
void onConnect(BLEServer* pServer) {
deviceConnected = true;
Serial.println("Device connected");
}
void onDisconnect(BLEServer* pServer) {
deviceConnected = false;
Serial.println("Device disconnected");
}
};
void setup() {
Serial.begin(115200);
// Create BLE Device
BLEDevice::init("ESP32-BLE");
// Create BLE Server
pServer = BLEDevice::createServer();
pServer->setCallbacks(new MyServerCallbacks());
// Create BLE Service
BLEService *pService = pServer->createService(SERVICE_UUID);
// Create BLE Characteristic
pCharacteristic = pService->createCharacteristic(
CHARACTERISTIC_UUID,
BLECharacteristic::PROPERTY_READ |
BLECharacteristic::PROPERTY_WRITE |
BLECharacteristic::PROPERTY_NOTIFY
);
pCharacteristic->addDescriptor(new BLE2902());
// Start service
pService->start();
// Start advertising
BLEAdvertising *pAdvertising = BLEDevice::getAdvertising();
pAdvertising->addServiceUUID(SERVICE_UUID);
pAdvertising->start();
Serial.println("BLE Server started. Waiting for connections...");
}
void loop() {
if (deviceConnected) {
// Update and notify characteristic
pCharacteristic->setValue((uint8_t*)&value, 4);
pCharacteristic->notify();
value++;
delay(1000);
}
}
BLE Client (Scanner)
#include <BLEDevice.h>
#include <BLEUtils.h>
#include <BLEScan.h>
#include <BLEAdvertisedDevice.h>
BLEScan* pBLEScan;
class MyAdvertisedDeviceCallbacks: public BLEAdvertisedDeviceCallbacks {
void onResult(BLEAdvertisedDevice advertisedDevice) {
Serial.printf("Found device: %s\n", advertisedDevice.toString().c_str());
}
};
void setup() {
Serial.begin(115200);
Serial.println("Starting BLE scan...");
BLEDevice::init("");
pBLEScan = BLEDevice::getScan();
pBLEScan->setAdvertisedDeviceCallbacks(new MyAdvertisedDeviceCallbacks());
pBLEScan->setActiveScan(true);
}
void loop() {
BLEScanResults foundDevices = pBLEScan->start(5, false);
Serial.printf("Devices found: %d\n", foundDevices.getCount());
pBLEScan->clearResults();
delay(2000);
}
Advanced Features
Deep Sleep Mode
#define uS_TO_S_FACTOR 1000000 // Conversion factor for microseconds to seconds
#define TIME_TO_SLEEP 30 // Sleep for 30 seconds
RTC_DATA_ATTR int bootCount = 0; // Preserved in RTC memory
void setup() {
Serial.begin(115200);
delay(1000);
bootCount++;
Serial.println("Boot number: " + String(bootCount));
// Configure wake-up sources
esp_sleep_enable_timer_wakeup(TIME_TO_SLEEP * uS_TO_S_FACTOR);
// GPIO wake-up
esp_sleep_enable_ext0_wakeup(GPIO_NUM_33, 1); // Wake on HIGH
Serial.println("Going to sleep for " + String(TIME_TO_SLEEP) + " seconds");
Serial.flush();
esp_deep_sleep_start();
}
void loop() {
// Never reached
}
Touch Wake-up
#define THRESHOLD 40
void setup() {
Serial.begin(115200);
delay(1000);
// Configure touch wake-up
touchAttachInterrupt(T0, callback, THRESHOLD);
esp_sleep_enable_touchpad_wakeup();
Serial.println("Going to sleep. Touch GPIO 4 to wake up.");
delay(1000);
esp_deep_sleep_start();
}
void callback() {
// Empty
}
void loop() {
// Never reached
}
ADC with Calibration
#include <esp_adc_cal.h>
#define ADC_PIN 34
#define DEFAULT_VREF 1100
esp_adc_cal_characteristics_t *adc_chars;
void setup() {
Serial.begin(115200);
// Configure ADC
adc1_config_width(ADC_WIDTH_BIT_12);
adc1_config_channel_atten(ADC1_CHANNEL_6, ADC_ATTEN_DB_11);
// Characterize ADC
adc_chars = (esp_adc_cal_characteristics_t*)calloc(1, sizeof(esp_adc_cal_characteristics_t));
esp_adc_cal_characterize(ADC_UNIT_1, ADC_ATTEN_DB_11, ADC_WIDTH_BIT_12, DEFAULT_VREF, adc_chars);
}
void loop() {
uint32_t adc_reading = 0;
// Multisampling
for (int i = 0; i < 64; i++) {
adc_reading += adc1_get_raw(ADC1_CHANNEL_6);
}
adc_reading /= 64;
// Convert to voltage
uint32_t voltage = esp_adc_cal_raw_to_voltage(adc_reading, adc_chars);
Serial.printf("Raw: %u, Voltage: %u mV\n", adc_reading, voltage);
delay(1000);
}
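The multisampling-plus-conversion step can be checked on the host with the idealized linear math. The 3100 mV full-scale figure below is an assumption for 11 dB attenuation; the real ADC response is nonlinear and chip-specific, which is exactly why `esp_adc_cal` exists:

```cpp
#include <cassert>
#include <cstdint>

// Average n raw 12-bit ADC samples, then convert to millivolts assuming a
// linear response where full scale (4095) maps to fullScaleMv. This is the
// idealized math only; esp_adc_cal corrects per-chip Vref error on hardware.
uint32_t rawToMillivolts(const uint16_t* samples, int n, uint32_t fullScaleMv) {
    uint32_t sum = 0;
    for (int i = 0; i < n; i++) sum += samples[i];
    uint32_t avg = sum / n;
    return avg * fullScaleMv / 4095;
}
```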
Over-The-Air (OTA) Updates
#include <WiFi.h>
#include <ESPmDNS.h>
#include <WiFiUdp.h>
#include <ArduinoOTA.h>
const char* ssid = "YourSSID";
const char* password = "YourPassword";
void setup() {
Serial.begin(115200);
WiFi.begin(ssid, password);
while (WiFi.status() != WL_CONNECTED) {
delay(500);
Serial.print(".");
}
Serial.println("\nWiFi connected");
Serial.println(WiFi.localIP());
// Setup OTA
ArduinoOTA.setHostname("esp32");
ArduinoOTA.setPassword("admin");
ArduinoOTA.onStart([]() {
String type = (ArduinoOTA.getCommand() == U_FLASH) ? "sketch" : "filesystem";
Serial.println("Start updating " + type);
});
ArduinoOTA.onEnd([]() {
Serial.println("\nEnd");
});
ArduinoOTA.onProgress([](unsigned int progress, unsigned int total) {
Serial.printf("Progress: %u%%\r", (progress / (total / 100)));
});
ArduinoOTA.onError([](ota_error_t error) {
Serial.printf("Error[%u]: ", error);
if (error == OTA_AUTH_ERROR) Serial.println("Auth Failed");
else if (error == OTA_BEGIN_ERROR) Serial.println("Begin Failed");
else if (error == OTA_CONNECT_ERROR) Serial.println("Connect Failed");
else if (error == OTA_RECEIVE_ERROR) Serial.println("Receive Failed");
else if (error == OTA_END_ERROR) Serial.println("End Failed");
});
ArduinoOTA.begin();
Serial.println("OTA Ready");
}
void loop() {
ArduinoOTA.handle();
}
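The `onProgress` callback above computes the percentage as `progress / (total / 100)`. Dividing `total` first keeps everything in range (computing `progress * 100` could overflow an `unsigned int` for large images). A host-side version of that arithmetic, with a guard added for tiny totals (the guard is an addition, not in the original sketch):

```cpp
#include <cassert>

// OTA progress percentage using the same integer math as the callback.
// Dividing total by 100 first avoids overflowing progress * 100.
unsigned int otaPercent(unsigned int progress, unsigned int total) {
    if (total < 100) return total ? progress * 100 / total : 0;
    return progress / (total / 100);
}
```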
Projects
Project 1: WiFi Weather Station
#include <WiFi.h>
#include <HTTPClient.h>
#include <DHT.h>
#define DHTPIN 4
#define DHTTYPE DHT11
DHT dht(DHTPIN, DHTTYPE);
const char* ssid = "YourSSID";
const char* password = "YourPassword";
const char* server = "http://api.thingspeak.com/update";
const char* apiKey = "YOUR_API_KEY";
void setup() {
Serial.begin(115200);
dht.begin();
WiFi.begin(ssid, password);
while (WiFi.status() != WL_CONNECTED) {
delay(500);
Serial.print(".");
}
Serial.println("\nConnected!");
}
void loop() {
float temp = dht.readTemperature();
float humidity = dht.readHumidity();
if (isnan(temp) || isnan(humidity)) {
Serial.println("Failed to read sensor!");
delay(2000);
return;
}
Serial.printf("Temp: %.1f°C, Humidity: %.1f%%\n", temp, humidity);
// Send to ThingSpeak
if (WiFi.status() == WL_CONNECTED) {
HTTPClient http;
String url = String(server) + "?api_key=" + apiKey +
"&field1=" + String(temp) +
"&field2=" + String(humidity);
http.begin(url);
int httpCode = http.GET();
if (httpCode > 0) {
Serial.println("Data sent successfully");
} else {
Serial.println("Error sending data");
}
http.end();
}
delay(20000); // ThingSpeak free tier requires at least 15 s between updates
}
Project 2: Bluetooth LED Controller
#include <BluetoothSerial.h>
BluetoothSerial SerialBT;
const int RED_PIN = 25;
const int GREEN_PIN = 26;
const int BLUE_PIN = 27;
void setup() {
Serial.begin(115200);
SerialBT.begin("ESP32-RGB");
pinMode(RED_PIN, OUTPUT);
pinMode(GREEN_PIN, OUTPUT);
pinMode(BLUE_PIN, OUTPUT);
Serial.println("Bluetooth RGB Controller Ready");
}
void loop() {
if (SerialBT.available()) {
String command = SerialBT.readStringUntil('\n');
command.trim();
if (command.startsWith("RGB")) {
// Format: RGB,255,128,64
int comma1 = command.indexOf(',');
int comma2 = command.indexOf(',', comma1 + 1);
int comma3 = command.indexOf(',', comma2 + 1);
int r = command.substring(comma1 + 1, comma2).toInt();
int g = command.substring(comma2 + 1, comma3).toInt();
int b = command.substring(comma3 + 1).toInt();
analogWrite(RED_PIN, r);
analogWrite(GREEN_PIN, g);
analogWrite(BLUE_PIN, b);
SerialBT.printf("Set RGB to %d,%d,%d\n", r, g, b);
} else if (command == "OFF") {
analogWrite(RED_PIN, 0);
analogWrite(GREEN_PIN, 0);
analogWrite(BLUE_PIN, 0);
SerialBT.println("LEDs OFF");
}
}
}
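The comma-splitting in the RGB handler can be tested off-target. This portable C++ version mirrors the `indexOf`/`substring` logic above (slightly tightened to require the `RGB,` prefix and all three commas; the function name is illustrative):

```cpp
#include <cassert>
#include <string>

// Parse "RGB,r,g,b" into three channel values, mirroring the sketch's
// substring logic. Returns false if the command doesn't match the format.
bool parseRGB(const std::string& cmd, int& r, int& g, int& b) {
    if (cmd.rfind("RGB,", 0) != 0) return false;   // must start with "RGB,"
    size_t c1 = cmd.find(',');
    size_t c2 = cmd.find(',', c1 + 1);
    size_t c3 = cmd.find(',', c2 + 1);
    if (c2 == std::string::npos || c3 == std::string::npos) return false;
    r = std::stoi(cmd.substr(c1 + 1, c2 - c1 - 1));
    g = std::stoi(cmd.substr(c2 + 1, c3 - c2 - 1));
    b = std::stoi(cmd.substr(c3 + 1));
    return true;
}
```

Validating all three separators before converting avoids garbage values when a truncated command arrives over Bluetooth.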
Project 3: WiFi Smart Thermostat
#include <WiFi.h>
#include <WebServer.h>
#include <DHT.h>
#define DHTPIN 4
#define DHTTYPE DHT11
#define RELAY_PIN 26
DHT dht(DHTPIN, DHTTYPE);
WebServer server(80);
const char* ssid = "YourSSID";
const char* password = "YourPassword";
float targetTemp = 25.0;
bool heatingOn = false;
void handleRoot() {
float temp = dht.readTemperature();
float humidity = dht.readHumidity();
String html = "<!DOCTYPE html><html><head>";
html += "<meta name='viewport' content='width=device-width, initial-scale=1'>";
html += "<style>body{font-family:Arial;text-align:center;margin:20px;}";
html += ".button{padding:15px;margin:10px;font-size:20px;}</style></head>";
html += "<body><h1>Smart Thermostat</h1>";
html += "<p>Current: " + String(temp, 1) + "°C</p>";
html += "<p>Humidity: " + String(humidity, 1) + "%</p>";
html += "<p>Target: " + String(targetTemp, 1) + "°C</p>";
html += "<p>Heating: " + String(heatingOn ? "ON" : "OFF") + "</p>";
html += "<a href='/increase'><button class='button'>+1°C</button></a>";
html += "<a href='/decrease'><button class='button'>-1°C</button></a>";
html += "</body></html>";
server.send(200, "text/html", html);
}
void handleIncrease() {
targetTemp += 1.0;
server.sendHeader("Location", "/");
server.send(303);
}
void handleDecrease() {
targetTemp -= 1.0;
server.sendHeader("Location", "/");
server.send(303);
}
void setup() {
Serial.begin(115200);
dht.begin();
pinMode(RELAY_PIN, OUTPUT);
WiFi.begin(ssid, password);
while (WiFi.status() != WL_CONNECTED) {
delay(500);
}
Serial.println("\nConnected!");
Serial.println(WiFi.localIP());
server.on("/", handleRoot);
server.on("/increase", handleIncrease);
server.on("/decrease", handleDecrease);
server.begin();
}
void loop() {
server.handleClient();
static unsigned long lastCheck = 0;
if (millis() - lastCheck > 5000) {
lastCheck = millis();
float temp = dht.readTemperature();
if (!isnan(temp)) {
if (temp < targetTemp - 0.5) {
heatingOn = true;
digitalWrite(RELAY_PIN, HIGH);
} else if (temp > targetTemp + 0.5) {
heatingOn = false;
digitalWrite(RELAY_PIN, LOW);
}
}
}
}
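The relay logic above implements hysteresis: heating switches on below `target - 0.5` and off above `target + 0.5`, and inside that band the previous state is kept. That dead band is what prevents the relay from chattering around the setpoint. The control decision, extracted into a pure function for host-side testing (function name is illustrative):

```cpp
#include <cassert>

// Thermostat hysteresis as in the sketch: on below target-0.5,
// off above target+0.5, previous state kept inside the dead band.
bool updateHeating(bool heatingOn, float temp, float target) {
    if (temp < target - 0.5f) return true;
    if (temp > target + 0.5f) return false;
    return heatingOn;   // inside the band: no change
}
```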
Best Practices
- Power Management: Use deep sleep for battery-powered projects
- WiFi: Disconnect when not needed to save power
- Watchdog: Enable watchdog timer for reliability
- OTA Updates: Implement for remote firmware updates
- Error Handling: Always check WiFi connection status
- Security: Use HTTPS and encrypted connections
- Memory: Monitor heap usage with ESP.getFreeHeap()
Troubleshooting
Common Issues
Boot Loop:
- Check strapping pins (0, 2, 5, 12, 15)
- Ensure stable power supply (500mA minimum)
- Add 10µF capacitor on EN pin
WiFi Not Connecting:
- Check SSID and password
- Verify 2.4GHz network (ESP32 doesn't support 5GHz)
- Move closer to router
Upload Failed:
- Hold BOOT button during upload
- Check correct COM port selected
- Try a lower upload baud rate (e.g., 115200 instead of 921600)
Brown-out Detector:
- Use external 5V power supply
- Add bulk capacitor (100-1000µF)
Resources
- Espressif Documentation: https://docs.espressif.com/
- ESP32 Arduino Core: https://github.com/espressif/arduino-esp32
- ESP-IDF Programming Guide: https://docs.espressif.com/projects/esp-idf/
- Forum: https://www.esp32.com/
See Also
Raspberry Pi
Complete guide to Raspberry Pi setup, GPIO programming, and projects.
Table of Contents
- Introduction
- Hardware Overview
- Setup and Installation
- GPIO Programming
- Python Programming
- C/C++ Programming
- Interfaces
- Projects
Introduction
The Raspberry Pi is a series of small single-board computers developed by the Raspberry Pi Foundation. Unlike microcontrollers, it runs a full Linux operating system and can function as a complete desktop computer.
Key Features
- Full Linux Operating System (Raspberry Pi OS based on Debian)
- High Processing Power: Multi-core ARM processors
- Rich Connectivity: USB, Ethernet, WiFi, Bluetooth, HDMI
- GPIO Interface: 40-pin header for hardware projects
- Programming: Python, C/C++, JavaScript, and more
- Price: $15-$75 depending on model
Model Comparison
| Model | Processor | RAM | USB | Ethernet | WiFi/BT | GPIO | Price |
|---|---|---|---|---|---|---|---|
| Pi Zero W | Single 1GHz | 512MB | 1 micro | No | Yes | 40 | $15 |
| Pi 3 B+ | Quad 1.4GHz | 1GB | 4 | Gigabit | Yes | 40 | $35 |
| Pi 4 B | Quad 1.5GHz | 2-8GB | 4 | Gigabit | Yes | 40 | $35-75 |
| Pi 5 | Quad 2.4GHz | 4-8GB | 4 | Gigabit | Yes | 40 | $60-80 |
| Pico | Dual 133MHz (RP2040) | 264KB | 1 micro | No | No | 26 | $4 |
Hardware Overview
Raspberry Pi 4 Board Layout
┌────────────────────────────────────────────────────────┐
│ USB-C Power ┌──────────────┐ │
│ ┌─┐ │ Ethernet │ │
│ └─┘ │ Port │ │
│ └──────────────┘ │
│ ┌────────┐ ┌────────┐ │
│ │ USB │ │ USB │ ┌──────────────┐ │
│ │ 2.0 │ │ 3.0 │ │ Dual HDMI │ │
│ └────────┘ └────────┘ └──────────────┘ │
│ │
│ ┌────────────────────┐ ┌──────┐ │
│ │ BCM2711 SoC │ │Audio │ │
│ │ Quad Cortex-A72 │ │Jack │ │
│ └────────────────────┘ └──────┘ │
│ │
│ ┌──────────────┐ ┌────────────────────────────┐ │
│ │ Micro SD │ │ 40-pin GPIO Header │ │
│ │ Card Slot │ │ │ │
│ └──────────────┘ └────────────────────────────┘ │
│ │
│ [CSI Camera] [DSI Display] │
└────────────────────────────────────────────────────────┘
GPIO Pinout (40-pin Header)
3V3 (1) (2) 5V
GPIO 2/SDA (3) (4) 5V
GPIO 3/SCL (5) (6) GND
GPIO 4 (7) (8) GPIO 14/TXD
GND (9) (10) GPIO 15/RXD
GPIO 17 (11) (12) GPIO 18/PWM
GPIO 27 (13) (14) GND
GPIO 22 (15) (16) GPIO 23
3V3 (17) (18) GPIO 24
GPIO 10/MOSI (19) (20) GND
GPIO 9/MISO (21) (22) GPIO 25
GPIO 11/SCLK (23) (24) GPIO 8/CE0
GND (25) (26) GPIO 7/CE1
GPIO 0 (27) (28) GPIO 1
GPIO 5 (29) (30) GND
GPIO 6 (31) (32) GPIO 12/PWM
GPIO 13/PWM (33) (34) GND
GPIO 19/PWM (35) (36) GPIO 16
GPIO 26 (37) (38) GPIO 20
GND (39) (40) GPIO 21
Power Pins: 3.3V (16mA max per GPIO), 5V (from USB input)
PWM: GPIO 12, 13, 18, 19
SPI0: MOSI(10), MISO(9), SCLK(11), CE0(8), CE1(7)
I2C1: SDA(2), SCL(3)
UART: TXD(14), RXD(15)
Setup and Installation
Initial Setup
1. Download Raspberry Pi OS
# Download Raspberry Pi Imager
# For Ubuntu/Debian:
sudo apt install rpi-imager
# For other systems, download from:
# https://www.raspberrypi.com/software/
2. Flash SD Card
# Using Raspberry Pi Imager (GUI):
# 1. Choose OS: Raspberry Pi OS (32-bit/64-bit)
# 2. Choose SD card
# 3. Click "Write"
# Or using command line (Linux):
sudo dd if=2023-05-03-raspios-bullseye-armhf.img of=/dev/sdX bs=4M status=progress
sync
3. Enable SSH (Headless Setup)
# Create empty 'ssh' file in boot partition
touch /media/username/boot/ssh
# Configure WiFi (optional)
cat > /media/username/boot/wpa_supplicant.conf << EOF
country=US
ctrl_interface=DIR=/var/run/wpa_supplicant GROUP=netdev
update_config=1
network={
ssid="YourNetworkName"
psk="YourPassword"
key_mgmt=WPA-PSK
}
EOF
Arduino Programming
Complete guide to Arduino development, from basics to advanced projects.
Table of Contents
- Introduction
- Getting Started
- Arduino Language
- Digital I/O
- Analog I/O
- Serial Communication
- Libraries
- Common Projects
- Advanced Topics
Introduction
Arduino is an open-source electronics platform based on easy-to-use hardware and software. It's designed for artists, designers, hobbyists, and anyone interested in creating interactive objects or environments.
Arduino Boards Comparison
| Board | MCU | Clock | Flash | RAM | Digital I/O | Analog In | Price |
|---|---|---|---|---|---|---|---|
| Uno | ATmega328P | 16 MHz | 32 KB | 2 KB | 14 (6 PWM) | 6 | $ |
| Mega 2560 | ATmega2560 | 16 MHz | 256 KB | 8 KB | 54 (15 PWM) | 16 | $$ |
| Nano | ATmega328P | 16 MHz | 32 KB | 2 KB | 14 (6 PWM) | 8 | $ |
| Leonardo | ATmega32u4 | 16 MHz | 32 KB | 2.5 KB | 20 (7 PWM) | 12 | $ |
| Due | AT91SAM3X8E | 84 MHz | 512 KB | 96 KB | 54 (12 PWM) | 12 | $$$ |
| Nano 33 IoT | SAMD21 | 48 MHz | 256 KB | 32 KB | 14 (11 PWM) | 8 | $$ |
Arduino Uno Pinout
Arduino Uno
┌─────────────┐
│ USB │
├─────────────┤
RESET [ ]──┤ RESET A0 ├──[ ] Analog Input
3.3V [ ]──┤ 3V3 A1 ├──[ ] Analog Input
5V [ ]──┤ 5V A2 ├──[ ] Analog Input
GND [ ]──┤ GND A3 ├──[ ] Analog Input
GND [ ]──┤ GND A4 ├──[ ] Analog Input (I2C SDA)
VIN [ ]──┤ VIN A5 ├──[ ] Analog Input (I2C SCL)
│ │
D0/RX [ ]──┤ 0 13 ├──[ ] D13/SCK (LED_BUILTIN)
D1/TX [ ]──┤ 1 12 ├──[ ] D12/MISO
D2 [ ]──┤ 2 11 ├──[ ] D11~/MOSI
D3~ [ ]──┤ 3 10 ├──[ ] D10~
D4 [ ]──┤ 4 9 ├──[ ] D9~
D5~ [ ]──┤ 5 8 ├──[ ] D8
D6~ [ ]──┤ 6 7 ├──[ ] D7
└─────────────┘
~ = PWM capable
Getting Started
Installation
Arduino IDE
# Download from arduino.cc
# Or use package manager (Linux)
sudo apt install arduino
# Or use Arduino CLI
curl -fsSL https://raw.githubusercontent.com/arduino/arduino-cli/master/install.sh | sh
arduino-cli core update-index
arduino-cli core install arduino:avr
PlatformIO (Recommended for Advanced Users)
pip install platformio
platformio init --board uno
Basic Program Structure
Every Arduino sketch has two required functions:
void setup() {
// Runs once when the board starts
// Initialize pins, serial, libraries
}
void loop() {
// Runs continuously after setup()
// Main program logic goes here
}
First Program: Blink LED
// Blink the built-in LED
void setup() {
pinMode(LED_BUILTIN, OUTPUT); // Set pin 13 as output
}
void loop() {
digitalWrite(LED_BUILTIN, HIGH); // Turn LED on
delay(1000); // Wait 1 second
digitalWrite(LED_BUILTIN, LOW); // Turn LED off
delay(1000); // Wait 1 second
}
Wiring:
Arduino Component
13 ───────────┐
│
┌─┴─┐
│LED│ Built-in LED
└─┬─┘
│
GND ──────────┘
Arduino Language
Data Types
// Boolean
bool flag = true;
// Integers
byte value = 255; // 0-255 (8-bit unsigned)
int temperature = -40; // -32768 to 32767 (16-bit signed)
unsigned int count = 65535; // 0-65535 (16-bit unsigned)
long distance = 1000000L; // 32-bit signed
unsigned long time = millis(); // 32-bit unsigned
// Floating Point
float voltage = 3.3; // 32-bit, ~7 digits precision
double precise = 3.14159; // Same as float (32-bit) on AVR boards like the Uno
// Characters and Strings
char letter = 'A';
char message[] = "Hello"; // C-style string
String text = "World"; // Arduino String class
// Arrays
int readings[10]; // Array of 10 integers
int values[] = {1, 2, 3}; // Initialized array
Control Structures
// If-else
if (temperature > 30) {
digitalWrite(FAN_PIN, HIGH);
} else if (temperature > 20) {
analogWrite(FAN_PIN, 128);
} else {
digitalWrite(FAN_PIN, LOW);
}
// Switch-case
switch (state) {
case 0:
// Do something
break;
case 1:
// Do something else
break;
default:
// Default action
break;
}
// For loop
for (int i = 0; i < 10; i++) {
Serial.println(i);
}
// While loop
while (digitalRead(BUTTON_PIN) == HIGH) {
// Wait for button release
}
// Do-while loop
do {
value = analogRead(A0);
} while (value < 512);
Functions
// Function declaration
int addNumbers(int a, int b);
void setup() {
Serial.begin(9600);
int result = addNumbers(5, 3);
Serial.println(result); // Prints 8
}
// Function definition
int addNumbers(int a, int b) {
return a + b;
}
// Function with default parameters
void blinkLED(int pin, int times = 1, int delayTime = 500) {
for (int i = 0; i < times; i++) {
digitalWrite(pin, HIGH);
delay(delayTime);
digitalWrite(pin, LOW);
delay(delayTime);
}
}
void loop() {
blinkLED(13); // Blink once
blinkLED(13, 3); // Blink 3 times
blinkLED(13, 5, 200); // Blink 5 times with 200ms delay
}
Digital I/O
Basic Digital Functions
pinMode(pin, mode); // Configure pin: INPUT, OUTPUT, INPUT_PULLUP
digitalWrite(pin, value); // Write HIGH or LOW
int value = digitalRead(pin); // Read HIGH or LOW
LED Control
// Simple LED control
const int LED_PIN = 9;
void setup() {
pinMode(LED_PIN, OUTPUT);
}
void loop() {
digitalWrite(LED_PIN, HIGH);
delay(500);
digitalWrite(LED_PIN, LOW);
delay(500);
}
Wiring:
Arduino Component
9 ───────────┬─────────┐
│ │
┌─┴─┐ ┌─┴─┐
│220│ │LED│
│Ω │ │ > │
└─┬─┘ └─┬─┘
│ │
GND ──────────┴─────────┘
Button Input
const int BUTTON_PIN = 2;
const int LED_PIN = 13;
void setup() {
pinMode(BUTTON_PIN, INPUT_PULLUP); // Internal pull-up resistor
pinMode(LED_PIN, OUTPUT);
}
void loop() {
int buttonState = digitalRead(BUTTON_PIN);
if (buttonState == LOW) { // Button pressed (active LOW)
digitalWrite(LED_PIN, HIGH);
} else {
digitalWrite(LED_PIN, LOW);
}
}
Wiring:
Arduino          Button
   2 ───────────┐
              ┌─┴──┐
              │BTN │
              └─┬──┘
GND ────────────┘
(INPUT_PULLUP enables the internal pull-up, so no external
resistor is needed; the pin reads LOW when the button is pressed)
Debouncing
const int BUTTON_PIN = 2;
const int LED_PIN = 13;
const int DEBOUNCE_DELAY = 50; // milliseconds
int lastButtonState = HIGH;
int buttonState = HIGH;
unsigned long lastDebounceTime = 0;
bool ledState = false;
void setup() {
pinMode(BUTTON_PIN, INPUT_PULLUP);
pinMode(LED_PIN, OUTPUT);
}
void loop() {
int reading = digitalRead(BUTTON_PIN);
// If the reading changed (noise or a press), restart the debounce timer
if (reading != lastButtonState) {
lastDebounceTime = millis();
}
if ((millis() - lastDebounceTime) > DEBOUNCE_DELAY) {
// If the button state has changed
if (reading != buttonState) {
buttonState = reading;
// Only toggle if the new button state is LOW (pressed)
if (buttonState == LOW) {
ledState = !ledState;
digitalWrite(LED_PIN, ledState);
}
}
}
lastButtonState = reading;
}
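The debounce logic above can be tested off-target by injecting the clock instead of calling `millis()`. This host-side C++ sketch packages the same state machine (struct and method names are illustrative); it reports `true` exactly once per HIGH-to-LOW edge that stays stable past the 50 ms window:

```cpp
#include <cassert>

// Debounce state machine from the sketch, with time passed in so it can
// run on a host. 1 = HIGH (pull-up idle), 0 = LOW (pressed).
struct Debouncer {
    int lastReading = 1;
    int stableState = 1;
    unsigned long lastChange = 0;

    // Returns true only on a debounced press (stable HIGH -> LOW edge).
    bool update(int reading, unsigned long nowMs) {
        if (reading != lastReading) lastChange = nowMs;  // restart timer
        lastReading = reading;
        if (nowMs - lastChange > 50 && reading != stableState) {
            stableState = reading;
            return stableState == 0;   // report press edges only
        }
        return false;
    }
};
```

Feeding it a bouncy trace (press, glitch back HIGH, press again, then stable) shows a single press event, which is exactly the behavior the LED-toggle sketch relies on.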
Analog I/O
Analog Input (ADC)
analogRead(pin); // Read analog value (0-1023)
Reading a Potentiometer
const int POT_PIN = A0;
const int LED_PIN = 9;
void setup() {
Serial.begin(9600);
pinMode(LED_PIN, OUTPUT);
}
void loop() {
int potValue = analogRead(POT_PIN); // 0-1023
// Convert to voltage (0-5V)
float voltage = potValue * (5.0 / 1023.0);
// Convert to LED brightness (0-255)
int brightness = map(potValue, 0, 1023, 0, 255);
Serial.print("Value: ");
Serial.print(potValue);
Serial.print(" Voltage: ");
Serial.print(voltage);
Serial.println("V");
analogWrite(LED_PIN, brightness);
delay(100);
}
Wiring:
Potentiometer Arduino
┌────┐
5V─┤1 3├─GND
│ │
│ 2 ├─A0 (wiper)
└────┘
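The `map()` call above is plain integer linear interpolation. A portable re-implementation of its documented formula (here named `mapValue` to avoid shadowing the Arduino built-in) makes the rounding behavior easy to test on a host:

```cpp
#include <cassert>

// Integer linear interpolation, same formula as Arduino's map().
// Note: integer division truncates, so midpoints round down.
long mapValue(long x, long inMin, long inMax, long outMin, long outMax) {
    return (x - inMin) * (outMax - outMin) / (inMax - inMin) + outMin;
}
```

Note that `map(512, 0, 1023, 0, 255)` gives 127, not 128: the truncating division biases results slightly low across the range.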
Analog Output (PWM)
analogWrite(pin, value); // PWM output (0-255)
Fading LED
const int LED_PIN = 9; // Must be PWM pin (~)
void setup() {
pinMode(LED_PIN, OUTPUT);
}
void loop() {
// Fade in
for (int brightness = 0; brightness <= 255; brightness++) {
analogWrite(LED_PIN, brightness);
delay(10);
}
// Fade out
for (int brightness = 255; brightness >= 0; brightness--) {
analogWrite(LED_PIN, brightness);
delay(10);
}
}
Temperature Sensor (LM35)
const int TEMP_PIN = A0;
void setup() {
Serial.begin(9600);
}
void loop() {
int reading = analogRead(TEMP_PIN);
// Convert to voltage (0-5V)
float voltage = reading * (5.0 / 1023.0);
// Convert to temperature (LM35: 10mV per degree C)
float temperatureC = voltage * 100.0;
float temperatureF = (temperatureC * 9.0 / 5.0) + 32.0;
Serial.print("Temperature: ");
Serial.print(temperatureC);
Serial.print("°C / ");
Serial.print(temperatureF);
Serial.println("°F");
delay(1000);
}
Wiring:
LM35 Sensor Arduino
┌────┐
1 ─┤ VS ├─ 5V
│ │
2 ─┤Vout├─ A0
│ │
3 ─┤GND ├─ GND
└────┘
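The LM35 conversion above is worth checking numerically: a 10-bit reading scales to volts against the 5 V reference, and the LM35's 10 mV/°C slope turns volts into degrees. The same math as host-side functions (names are illustrative):

```cpp
#include <cassert>
#include <cmath>

// 10-bit ADC reading -> degrees Celsius, assuming a 5 V reference
// and the LM35's 10 mV per degree output.
float adcToCelsius(int reading) {
    float voltage = reading * (5.0f / 1023.0f);
    return voltage * 100.0f;   // 10 mV/°C => volts * 100
}

float celsiusToFahrenheit(float c) {
    return c * 9.0f / 5.0f + 32.0f;
}
```

For example, a reading of 102 is about 0.499 V, i.e. roughly 49.9 °C; each ADC count is worth about 0.49 °C, which sets the resolution floor of this setup.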
Serial Communication
Basic Serial Functions
Serial.begin(baudrate); // Initialize serial (9600, 115200, etc.)
Serial.print(data); // Print without newline
Serial.println(data); // Print with newline
Serial.write(byte); // Send raw byte
int available = Serial.available(); // Bytes available to read
char c = Serial.read(); // Read one byte
String line = Serial.readStringUntil('\n'); // Read until newline
Serial Monitor Output
void setup() {
Serial.begin(9600);
Serial.println("Arduino Ready!");
}
void loop() {
int sensorValue = analogRead(A0);
// Different formatting options
Serial.print("Sensor: ");
Serial.println(sensorValue);
Serial.print("Hex: 0x");
Serial.println(sensorValue, HEX);
Serial.print("Binary: 0b");
Serial.println(sensorValue, BIN);
Serial.print("Float: ");
float voltage = sensorValue * (5.0 / 1023.0);
Serial.println(voltage, 2); // 2 decimal places
delay(1000);
}
Serial Input
String inputString = "";
bool stringComplete = false;
void setup() {
Serial.begin(9600);
inputString.reserve(200); // Reserve space for efficiency
}
void loop() {
// Check if data is available
while (Serial.available()) {
char inChar = (char)Serial.read();
inputString += inChar;
if (inChar == '\n') {
stringComplete = true;
}
}
// Process complete command
if (stringComplete) {
Serial.print("Received: ");
Serial.println(inputString);
// Process command
if (inputString.startsWith("LED ON")) {
digitalWrite(LED_BUILTIN, HIGH);
Serial.println("LED turned ON");
} else if (inputString.startsWith("LED OFF")) {
digitalWrite(LED_BUILTIN, LOW);
Serial.println("LED turned OFF");
}
// Clear the string
inputString = "";
stringComplete = false;
}
}
Libraries
Built-in Libraries
Wire (I2C)
#include <Wire.h>
void setup() {
Wire.begin(); // Join I2C bus as master
}
void loop() {
// Read from I2C device at address 0x68
Wire.beginTransmission(0x68);
Wire.write(0x00); // Register address
Wire.endTransmission();
Wire.requestFrom(0x68, 1); // Request 1 byte
if (Wire.available()) {
byte data = Wire.read();
}
}
SPI
#include <SPI.h>
const int CS_PIN = 10;
void setup() {
SPI.begin();
pinMode(CS_PIN, OUTPUT);
digitalWrite(CS_PIN, HIGH);
}
void loop() {
digitalWrite(CS_PIN, LOW); // Select device
SPI.transfer(0xAB); // Send byte
byte received = SPI.transfer(0x00); // Receive byte
digitalWrite(CS_PIN, HIGH); // Deselect device
}
EEPROM
#include <EEPROM.h>
void setup() {
// Write byte to EEPROM
EEPROM.write(0, 42);
// Read byte from EEPROM
byte value = EEPROM.read(0);
// Update (only writes if different)
EEPROM.update(0, 42);
// Write/read other types
int address = 0;
float f = 3.14;
EEPROM.put(address, f);
EEPROM.get(address, f);
}
Popular External Libraries
Servo Control
#include <Servo.h>
Servo myServo;
const int SERVO_PIN = 9;
void setup() {
myServo.attach(SERVO_PIN);
}
void loop() {
// Sweep from 0 to 180 degrees
for (int pos = 0; pos <= 180; pos++) {
myServo.write(pos);
delay(15);
}
// Sweep back
for (int pos = 180; pos >= 0; pos--) {
myServo.write(pos);
delay(15);
}
}
Wiring:
Servo Motor Arduino
┌────┐
R ─┤Red ├─ 5V (or external)
B ─┤Brn ├─ GND
O ─┤Org ├─ Pin 9 (PWM)
└────┘
LiquidCrystal (LCD Display)
#include <LiquidCrystal.h>
// RS, E, D4, D5, D6, D7
LiquidCrystal lcd(12, 11, 5, 4, 3, 2);
void setup() {
lcd.begin(16, 2); // 16x2 LCD
lcd.print("Hello, World!");
}
void loop() {
lcd.setCursor(0, 1); // Column 0, Row 1
lcd.print(millis() / 1000);
lcd.print("s");
delay(100);
}
Wiring:
LCD 16x2 Arduino
VSS ────────────── GND
VDD ────────────── 5V
V0 ────────────── Potentiometer (contrast)
RS ────────────── 12
RW ────────────── GND
E ────────────── 11
D4 ────────────── 5
D5 ────────────── 4
D6 ────────────── 3
D7 ────────────── 2
A ────────────── 5V (backlight)
K ────────────── GND
DHT Temperature/Humidity Sensor
#include <DHT.h>
#define DHTPIN 2
#define DHTTYPE DHT11 // or DHT22
DHT dht(DHTPIN, DHTTYPE);
void setup() {
Serial.begin(9600);
dht.begin();
}
void loop() {
float humidity = dht.readHumidity();
float temperature = dht.readTemperature(); // Celsius
float temperatureF = dht.readTemperature(true); // Fahrenheit
if (isnan(humidity) || isnan(temperature)) {
Serial.println("Failed to read from DHT sensor!");
return;
}
Serial.print("Humidity: ");
Serial.print(humidity);
Serial.print("% Temperature: ");
Serial.print(temperature);
Serial.println("°C");
delay(2000); // DHT11 minimum sampling period
}
Common Projects
Project 1: Traffic Light
const int RED_LED = 10;
const int YELLOW_LED = 9;
const int GREEN_LED = 8;
void setup() {
pinMode(RED_LED, OUTPUT);
pinMode(YELLOW_LED, OUTPUT);
pinMode(GREEN_LED, OUTPUT);
}
void loop() {
// Green light
digitalWrite(GREEN_LED, HIGH);
delay(5000); // 5 seconds
digitalWrite(GREEN_LED, LOW);
// Yellow light
digitalWrite(YELLOW_LED, HIGH);
delay(2000); // 2 seconds
digitalWrite(YELLOW_LED, LOW);
// Red light
digitalWrite(RED_LED, HIGH);
delay(5000); // 5 seconds
digitalWrite(RED_LED, LOW);
}
Wiring:
Arduino LEDs
10 ───┬───[220Ω]───[RED LED]───GND
│
9 ───┼───[220Ω]───[YEL LED]───GND
│
8 ───┴───[220Ω]───[GRN LED]───GND
Project 2: Ultrasonic Distance Sensor
const int TRIG_PIN = 9;
const int ECHO_PIN = 10;
void setup() {
Serial.begin(9600);
pinMode(TRIG_PIN, OUTPUT);
pinMode(ECHO_PIN, INPUT);
}
void loop() {
// Send ultrasonic pulse
digitalWrite(TRIG_PIN, LOW);
delayMicroseconds(2);
digitalWrite(TRIG_PIN, HIGH);
delayMicroseconds(10);
digitalWrite(TRIG_PIN, LOW);
// Measure echo duration
long duration = pulseIn(ECHO_PIN, HIGH);
// Calculate distance in cm
// Speed of sound: 343 m/s = 0.0343 cm/µs
// Distance = (duration / 2) * 0.0343
float distance = duration * 0.0343 / 2;
Serial.print("Distance: ");
Serial.print(distance);
Serial.println(" cm");
delay(100);
}
Wiring:
HC-SR04 Arduino
VCC ─────────────── 5V
Trig ────────────── 9
Echo ────────────── 10
GND ─────────────── GND
Project 3: Light-Activated Switch
const int LDR_PIN = A0;
const int LED_PIN = 13;
const int THRESHOLD = 500; // Adjust based on lighting
void setup() {
Serial.begin(9600);
pinMode(LED_PIN, OUTPUT);
}
void loop() {
int lightLevel = analogRead(LDR_PIN);
Serial.print("Light Level: ");
Serial.println(lightLevel);
if (lightLevel < THRESHOLD) {
digitalWrite(LED_PIN, HIGH); // Turn on LED when dark
} else {
digitalWrite(LED_PIN, LOW); // Turn off LED when bright
}
delay(100);
}
Wiring:
5V
│
┌─┴─┐
│LDR│ (Light Dependent Resistor)
└─┬─┘
├────── A0
┌─┴─┐
│10k│ (Pull-down resistor)
│Ω │
└─┬─┘
│
GND
Project 4: Temperature-Controlled Fan
#include <DHT.h>
#define DHTPIN 2
#define DHTTYPE DHT11
#define FAN_PIN 9 // PWM pin for fan control
DHT dht(DHTPIN, DHTTYPE);
const float TEMP_MIN = 25.0; // Start fan at 25°C
const float TEMP_MAX = 35.0; // Full speed at 35°C
void setup() {
Serial.begin(9600);
pinMode(FAN_PIN, OUTPUT);
dht.begin();
}
void loop() {
float temperature = dht.readTemperature();
if (isnan(temperature)) {
Serial.println("Failed to read temperature!");
return;
}
// Calculate fan speed (0-255)
int fanSpeed = 0;
if (temperature < TEMP_MIN) {
fanSpeed = 0;
} else if (temperature > TEMP_MAX) {
fanSpeed = 255;
} else {
fanSpeed = map(temperature * 10, TEMP_MIN * 10, TEMP_MAX * 10, 0, 255);
}
analogWrite(FAN_PIN, fanSpeed);
Serial.print("Temperature: ");
Serial.print(temperature);
Serial.print("°C Fan Speed: ");
Serial.print((fanSpeed * 100) / 255);
Serial.println("%");
delay(2000);
}
Project 5: Simple Data Logger
#include <SD.h>
#include <SPI.h>
const int CS_PIN = 10;
const int SENSOR_PIN = A0;
File dataFile;
void setup() {
Serial.begin(9600);
// Initialize SD card
if (!SD.begin(CS_PIN)) {
Serial.println("SD card initialization failed!");
return;
}
Serial.println("SD card initialized.");
}
void loop() {
int sensorValue = analogRead(SENSOR_PIN);
float voltage = sensorValue * (5.0 / 1023.0);
// Open file for writing
dataFile = SD.open("datalog.txt", FILE_WRITE);
if (dataFile) {
// Write timestamp and value
dataFile.print(millis());
dataFile.print(",");
dataFile.println(voltage);
dataFile.close();
Serial.print("Logged: ");
Serial.println(voltage);
} else {
Serial.println("Error opening file!");
}
delay(1000); // Log every second
}
Advanced Topics
Interrupts
const int BUTTON_PIN = 2; // Must be interrupt-capable pin
const int LED_PIN = 13;
volatile bool ledState = false;
void setup() {
pinMode(BUTTON_PIN, INPUT_PULLUP);
pinMode(LED_PIN, OUTPUT);
// Attach interrupt: pin, ISR function, trigger mode
attachInterrupt(digitalPinToInterrupt(BUTTON_PIN), buttonISR, FALLING);
}
void loop() {
// Main loop can do other things
// LED toggle happens immediately when button pressed
}
// Interrupt Service Routine (keep it short!)
void buttonISR() {
ledState = !ledState;
digitalWrite(LED_PIN, ledState);
}
Timers
unsigned long previousMillis = 0;
const long interval = 1000; // 1 second
void setup() {
pinMode(LED_BUILTIN, OUTPUT);
}
void loop() {
unsigned long currentMillis = millis();
// Non-blocking timing
if (currentMillis - previousMillis >= interval) {
previousMillis = currentMillis;
// Toggle LED
digitalWrite(LED_BUILTIN, !digitalRead(LED_BUILTIN));
}
// Can do other things here
}
Memory Optimization
// Store strings in flash memory (PROGMEM)
const char message[] PROGMEM = "This string is stored in flash";
void setup() {
Serial.begin(9600);
// Read from flash memory
char buffer[50];
strcpy_P(buffer, message);
Serial.println(buffer);
}
// Use F() macro for Serial.print
void loop() {
Serial.println(F("This uses flash memory, not RAM"));
delay(1000);
}
Low Power Mode
#include <avr/sleep.h>
#include <avr/power.h>
void setup() {
pinMode(LED_BUILTIN, OUTPUT);
pinMode(2, INPUT_PULLUP);
// Enable interrupt for wake-up
attachInterrupt(digitalPinToInterrupt(2), wakeUp, LOW);
}
void loop() {
digitalWrite(LED_BUILTIN, HIGH);
delay(1000);
digitalWrite(LED_BUILTIN, LOW);
// Enter sleep mode
enterSleep();
}
void enterSleep() {
set_sleep_mode(SLEEP_MODE_PWR_DOWN);
sleep_enable();
// Disable peripherals
power_adc_disable();
power_spi_disable();
power_timer0_disable();
power_timer1_disable();
power_timer2_disable();
power_twi_disable();
sleep_mode(); // Sleep here
// Wake up here
sleep_disable();
power_all_enable();
}
void wakeUp() {
// ISR to wake up
}
Best Practices
1. Avoid delay() for Responsive Code
// Bad: Blocking
void loop() {
digitalWrite(LED1, HIGH);
delay(1000);
digitalWrite(LED2, HIGH);
delay(500);
}
// Good: Non-blocking
unsigned long led1Time = 0;
unsigned long led2Time = 0;
void loop() {
unsigned long now = millis();
if (now - led1Time >= 1000) {
digitalWrite(LED1, !digitalRead(LED1));
led1Time = now;
}
if (now - led2Time >= 500) {
digitalWrite(LED2, !digitalRead(LED2));
led2Time = now;
}
}
2. Use const for Pin Definitions
// Good: Easy to change and read
const int LED_PIN = 13;
const int BUTTON_PIN = 2;
const int SENSOR_PIN = A0;
3. Check Return Values
if (!SD.begin(CS_PIN)) {
Serial.println("SD card failed!");
while (1); // Halt
}
4. Use Meaningful Variable Names
// Bad
int x = analogRead(A0);
// Good
int lightLevel = analogRead(LIGHT_SENSOR_PIN);
5. Comment Complex Logic
// Calculate distance from ultrasonic sensor
// Formula: distance (cm) = duration (µs) × 0.0343 / 2
// Division by 2 accounts for round-trip time
float distance = duration * 0.0343 / 2;
Troubleshooting
Common Issues
- Upload Failed
  - Check that the correct board and port are selected
  - Try pressing the reset button before uploading
  - Close the Serial Monitor during upload
- Serial Monitor Shows Garbage
  - Check that the baud rate matches the code
  - Verify the USB cable supports data (not just power)
- Sketch Too Large
  - Remove unused libraries
  - Use PROGMEM for strings
  - Optimize code
- Unexpected Behavior
  - Add Serial.println() statements for debugging
  - Check wiring and connections
  - Verify the power supply is adequate
Resources
- Official Documentation: https://www.arduino.cc/reference/
- Forum: https://forum.arduino.cc/
- Project Hub: https://create.arduino.cc/projecthub
- Libraries: https://www.arduinolibraries.info/
See Also
- ESP32 - More powerful Arduino-compatible platform
- AVR Programming - Low-level AVR microcontroller programming
- GPIO - Digital I/O concepts
- UART - Serial communication details
- SPI - SPI protocol
- I2C - I2C protocol
SPI (Serial Peripheral Interface)
Overview
SPI (Serial Peripheral Interface) is a synchronous serial communication protocol used for short-distance communication between microcontrollers and peripheral devices like sensors, displays, SD cards, and flash memory. Developed by Motorola in the 1980s, SPI is known for its high-speed, full-duplex communication capabilities.
Key Features
- Full-Duplex Communication: SPI can send and receive data simultaneously on separate lines
- High Speed: Typically operates at speeds from 1 MHz to over 50 MHz
- Master-Slave Architecture: Always has one master device controlling one or more slave devices
- Four-Wire Interface: Uses separate lines for clock, data in, data out, and chip select
- No Addressing: Slave selection is done via dedicated chip select lines
Signal Lines
SPI uses four main signal lines:
| Signal | Alternative Names | Description |
|---|---|---|
| SCLK | SCK, CLK | Serial Clock - Generated by master to synchronize data transfer |
| MOSI | SDO, DO, SIMO | Master Out Slave In - Data from master to slave |
| MISO | SDI, DI, SOMI | Master In Slave Out - Data from slave to master |
| SS | CS, NSS | Slave Select/Chip Select - Selects which slave is active |
Why Four Wires?
Unlike I2C's two-wire design, SPI uses separate data lines for sending and receiving, enabling full-duplex communication. Each slave device needs its own chip select line, which can increase pin count when multiple slaves are used.
How It Works
Basic Communication Flow
- Master selects slave: Pulls the slave's CS line LOW (active)
- Master generates clock: Starts toggling the SCLK line
- Data exchange: On each clock cycle:
- Master shifts data out on MOSI
- Slave shifts data out on MISO
- Both shift data in simultaneously
- Master deselects slave: Pulls CS line HIGH (inactive)
Clock Polarity and Phase (CPOL/CPHA)
SPI has four modes determined by two settings:
- CPOL (Clock Polarity): Determines the idle state of the clock
  - CPOL = 0: Clock idles LOW
  - CPOL = 1: Clock idles HIGH
- CPHA (Clock Phase): Determines when data is sampled
  - CPHA = 0: Data sampled on leading edge, shifted on trailing edge
  - CPHA = 1: Data sampled on trailing edge, shifted on leading edge
| Mode | CPOL | CPHA | Clock Idle | Data Sampled |
|---|---|---|---|---|
| 0 | 0 | 0 | LOW | Leading (rising) edge |
| 1 | 0 | 1 | LOW | Trailing (falling) edge |
| 2 | 1 | 0 | HIGH | Leading (falling) edge |
| 3 | 1 | 1 | HIGH | Trailing (rising) edge |
Important: Master and slave must use the same mode for successful communication!
Code Examples
Arduino SPI Communication
#include <SPI.h>
const int chipSelectPin = 10;
void setup() {
// Initialize SPI pins
pinMode(chipSelectPin, OUTPUT);
digitalWrite(chipSelectPin, HIGH); // Deselect slave initially
// Initialize SPI library
SPI.begin();
// Configure SPI settings
// Max speed: 4 MHz, MSB first, Mode 0
SPI.beginTransaction(SPISettings(4000000, MSBFIRST, SPI_MODE0));
}
void loop() {
// Select slave device
digitalWrite(chipSelectPin, LOW);
// Send a byte and receive response simultaneously
byte command = 0x3A; // Example command
byte response = SPI.transfer(command);
// Send multiple bytes
byte dataToSend[] = {0x01, 0x02, 0x03};
for (int i = 0; i < 3; i++) {
byte receivedByte = SPI.transfer(dataToSend[i]);
}
// Deselect slave
digitalWrite(chipSelectPin, HIGH);
delay(1000);
}
STM32 HAL SPI Example
#include "stm32f4xx_hal.h"
SPI_HandleTypeDef hspi1;
void SPI_Init(void) {
hspi1.Instance = SPI1;
hspi1.Init.Mode = SPI_MODE_MASTER;
hspi1.Init.Direction = SPI_DIRECTION_2LINES;
hspi1.Init.DataSize = SPI_DATASIZE_8BIT;
hspi1.Init.CLKPolarity = SPI_POLARITY_LOW; // CPOL = 0
hspi1.Init.CLKPhase = SPI_PHASE_1EDGE; // CPHA = 0
hspi1.Init.NSS = SPI_NSS_SOFT;
hspi1.Init.BaudRatePrescaler = SPI_BAUDRATEPRESCALER_16;
hspi1.Init.FirstBit = SPI_FIRSTBIT_MSB;
HAL_SPI_Init(&hspi1);
}
void SPI_Write_Read(uint8_t *txData, uint8_t *rxData, uint16_t size) {
HAL_GPIO_WritePin(GPIOA, GPIO_PIN_4, GPIO_PIN_RESET); // CS LOW
HAL_SPI_TransmitReceive(&hspi1, txData, rxData, size, 100);
HAL_GPIO_WritePin(GPIOA, GPIO_PIN_4, GPIO_PIN_SET); // CS HIGH
}
ESP32 SPI Example
#include <SPI.h>
SPIClass spi(HSPI); // Use HSPI bus
const int CS_PIN = 15;
void setup() {
spi.begin(14, 12, 13, CS_PIN); // SCLK, MISO, MOSI, SS
pinMode(CS_PIN, OUTPUT);
digitalWrite(CS_PIN, HIGH);
}
uint8_t readRegister(uint8_t reg) {
digitalWrite(CS_PIN, LOW);
spi.transfer(reg | 0x80); // Read bit set
uint8_t value = spi.transfer(0x00); // Dummy byte to read
digitalWrite(CS_PIN, HIGH);
return value;
}
void writeRegister(uint8_t reg, uint8_t value) {
digitalWrite(CS_PIN, LOW);
spi.transfer(reg & 0x7F); // Write bit clear
spi.transfer(value);
digitalWrite(CS_PIN, HIGH);
}
Common Use Cases
1. SD Card Communication
- High-speed file I/O
- Supports both SPI and SDIO modes
- Ideal for data logging applications
2. Display Interfaces
- TFT LCD displays (ILI9341, ST7735)
- OLED displays
- E-ink displays
- Fast refresh rates for graphics
3. Sensor Communication
- Digital accelerometers (ADXL345)
- Gyroscopes/IMUs (e.g., MPU9250; note the MPU6050 is I2C-only)
- Pressure sensors (BMP280)
4. Memory Devices
- Flash memory (W25Q128)
- EEPROM chips
- FRAM (Ferroelectric RAM)
5. Wireless Modules
- NRF24L01+ (2.4GHz transceiver)
- LoRa modules (SX1278)
- WiFi modules
SPI vs I2C Comparison
| Feature | SPI | I2C |
|---|---|---|
| Wires | 4 (+ 1 per additional slave) | 2 |
| Speed | Up to 50+ MHz | Up to 3.4 MHz |
| Duplex | Full-duplex | Half-duplex |
| Addressing | Hardware (CS pins) | Software (7/10-bit addresses) |
| Distance | Short (< 1 meter) | Short (< 1 meter) |
| Complexity | Simple protocol | More complex protocol |
| Multi-master | No (typically) | Yes |
| Pins required | Increases with slaves | Constant |
Best Practices
1. Wire Length and Speed
- Keep wires short (< 30cm) for high speeds
- Reduce speed for longer connections
- Use twisted pairs for MOSI/MISO on longer runs
2. Pull-up Resistors
- MISO should have a pull-up resistor (~10k ohm)
- Prevents floating when no slave is selected
- Some slave devices have built-in pull-ups
3. Chip Select Management
// Always wrap transfers in CS control
digitalWrite(CS, LOW);
// ... SPI operations ...
digitalWrite(CS, HIGH);
// For critical timing, disable interrupts
noInterrupts();
digitalWrite(CS, LOW);
SPI.transfer(data);
digitalWrite(CS, HIGH);
interrupts();
4. Power Considerations
- Ensure all devices share a common ground
- Check voltage levels (3.3V vs 5V)
- Use level shifters if needed
5. Multiple Slave Devices
// Daisy chain method (saves CS pins)
// Data flows through all slaves
digitalWrite(CS_SHARED, LOW);
SPI.transfer(dataForSlave1);
SPI.transfer(dataForSlave2);
SPI.transfer(dataForSlave3);
digitalWrite(CS_SHARED, HIGH);
// Individual CS method (parallel addressing)
digitalWrite(CS_SLAVE1, LOW);
SPI.transfer(dataForSlave1);
digitalWrite(CS_SLAVE1, HIGH);
Common Issues and Debugging
Problem: No Data Received
- Check CPOL/CPHA mode matches between master and slave
- Verify wiring (MOSI to MOSI, MISO to MISO)
- Ensure CS is properly toggled
- Check clock frequency is within slave's range
Problem: Corrupted Data
- Reduce SPI clock speed
- Check for loose connections
- Add small capacitors (10nF) near slave devices
- Ensure proper ground connections
Problem: Multiple Slaves Interfering
- Verify only one CS is active at a time
- Check for proper tri-state behavior on MISO
- Add pull-up on MISO line
ELI10 (Explain Like I'm 10)
Imagine you're playing a game with your friend where you both pass notes at the same time:
- The master is like the teacher who decides when to pass notes (controls the clock)
- MOSI is the note you pass to your friend
- MISO is the note your friend passes back to you
- Chip Select is like tapping your friend's shoulder to say "Hey, I'm talking to you!"
- Clock is like a metronome that keeps everyone in sync - you both write and read at the same time
The cool part? You can both write notes to each other at the exact same time! That's why SPI is called "full-duplex" - it's like talking and listening simultaneously.
Further Resources
- SPI Wikipedia - Detailed technical overview
- Analog Devices SPI Tutorial
- SparkFun SPI Tutorial
- Arduino SPI Library Reference
- Application notes from your microcontroller manufacturer
I2C
Overview
I2C (Inter-Integrated Circuit) is a synchronous, multi-master, multi-slave, packet-switched, single-ended, serial communication bus. It was developed by Philips Semiconductor (now NXP Semiconductors) in the 1980s to facilitate communication between integrated circuits on a single board.
Key Features
- Multi-Master Configuration: I2C allows multiple master devices to control the bus, enabling more complex communication scenarios.
- Two-Wire Interface: I2C uses only two wires for communication: the Serial Data Line (SDA) and the Serial Clock Line (SCL). This simplicity reduces the number of connections required.
- Addressing: Each device on the I2C bus has a unique address, allowing the master to communicate with specific slaves.
- Speed: I2C supports different data rates, typically 100 kbit/s (Standard Mode), 400 kbit/s (Fast Mode), and up to 3.4 Mbit/s (High-Speed Mode).
Applications
I2C is widely used in various applications, including:
- Sensor Communication: Many sensors, such as temperature, humidity, and accelerometers, use I2C to communicate with microcontrollers.
- Display Interfaces: LCD and OLED displays often utilize I2C for data transfer, simplifying the wiring and control.
- Memory Devices: EEPROMs and other memory devices frequently implement I2C for data storage and retrieval.
Signals
In the context of I2C, signals refer to the electrical signals used for communication between the master and slave devices on the bus. The key signals in the I2C interface include:
- SDA (Serial Data Line): This line carries the data being transmitted between devices. It is bidirectional, allowing both the master and slave devices to send and receive data.
- SCL (Serial Clock Line): This line provides the clock signal that synchronizes the data transfer between the master and slave devices. The master device generates the clock signal, ensuring that both devices are in sync during communication.
- Start Condition: This is a specific signal generated by the master to indicate the beginning of a data transmission. It is represented by a high-to-low transition on the SDA line while the SCL line is high.
- Stop Condition: This signal indicates the end of a data transmission. It is represented by a low-to-high transition on the SDA line while the SCL line is high.
- Acknowledgment (ACK): After each byte of data is transmitted, the receiving device sends an acknowledgment signal back to the sender. This is done by pulling the SDA line low during the ninth clock pulse.
- No Acknowledgment (NACK): If the receiving device does not acknowledge the received data, it will leave the SDA line high during the ninth clock pulse, indicating that the sender should stop transmitting.
These signals are essential for establishing communication, ensuring data integrity, and managing the flow of information between devices on the I2C bus.
Conclusion
I2C is a versatile and efficient communication protocol that is essential in embedded systems and electronic devices. Its simplicity and flexibility make it a popular choice for connecting various components in a wide range of applications.
UART (Universal Asynchronous Receiver-Transmitter)
Overview
UART is one of the most commonly used serial communication protocols in embedded systems. Unlike SPI and I2C, UART is asynchronous - meaning it doesn't require a shared clock signal between devices. This makes it simple, robust, and perfect for point-to-point communication between two devices.
Key Features
- Asynchronous: No shared clock required - devices use pre-agreed baud rates
- Point-to-Point: Communication between exactly two devices
- Two-Wire Interface: Only TX (transmit) and RX (receive) lines needed
- Full-Duplex: Can send and receive data simultaneously
- Simple: Easy to implement and debug
- Universal: Supported by virtually all microcontrollers
Signal Lines
UART uses only two main signal lines (plus ground):
| Signal | Description |
|---|---|
| TX | Transmit Data - Output from device |
| RX | Receive Data - Input to device |
| GND | Common ground reference |
Important Wiring: TX of device A connects to RX of device B, and vice versa!
Device A Device B
TX --------> RX
RX <-------- TX
GND --------- GND
How It Works
Data Frame Structure
A typical UART data frame consists of:
 Start    Data Bits     Parity    Stop
  Bit      (5-9)        (Opt)    Bit(s)
┌─────┬──┬──┬──┬──┬──┬──┬──┬──┬────────┬──────┐
│  S  │D0│D1│D2│D3│D4│D5│D6│D7│ P(opt) │ Stop │
└─────┴──┴──┴──┴──┴──┴──┴──┴──┴────────┴──────┘
- Idle State: Line is HIGH when no data is being sent
- Start Bit: Single LOW bit signals beginning of frame
- Data Bits: 5-9 bits of actual data (usually 8 bits)
- Parity Bit (Optional): Error checking bit
- Stop Bit(s): 1, 1.5, or 2 HIGH bits signal end of frame
Baud Rate
Baud rate is the speed of communication, measured in bits per second (bps).
Common Baud Rates:
- 9600 bps - Default for many applications
- 19200 bps
- 38400 bps
- 57600 bps
- 115200 bps - Common for debugging/logging
- 230400 bps
- 921600 bps - High-speed applications
Formula:
Bit Duration = 1 / Baud Rate
At 9600 baud: each bit takes ~104 microseconds
Parity Bit
Parity is a simple error detection method:
- Even Parity: Parity bit set so total number of 1s is even
- Odd Parity: Parity bit set so total number of 1s is odd
- None: No parity bit (most common)
Configuration Format
UART settings are often written as: Baud-Data-Parity-Stop
Examples:
- 9600-8-N-1: 9600 baud, 8 data bits, No parity, 1 stop bit (most common)
- 115200-8-E-1: 115200 baud, 8 data bits, Even parity, 1 stop bit
Code Examples
Arduino UART
void setup() {
// Initialize Serial (UART0) at 9600 baud
Serial.begin(9600);
// For other UART ports on boards like Arduino Mega:
// Serial1.begin(115200);
// Serial2.begin(9600);
// Wait for serial port to connect
while (!Serial) {
; // Wait for serial port to connect (needed for native USB)
}
Serial.println("UART initialized!");
}
void loop() {
// Sending data
Serial.print("Temperature: ");
Serial.println(25.5);
// Sending formatted data
char buffer[50];
sprintf(buffer, "Value: %d, Time: %lu", 42, millis());
Serial.println(buffer);
// Reading data
if (Serial.available() > 0) {
// Read a single byte
char incoming = Serial.read();
// Read until newline
String command = Serial.readStringUntil('\n');
// Read with timeout (default 1000ms)
Serial.setTimeout(500);
int value = Serial.parseInt();
Serial.print("Received: ");
Serial.println(command);
}
delay(1000);
}
ESP32 Multiple UARTs
// ESP32 has 3 hardware UARTs
HardwareSerial SerialGPS(1); // UART1
HardwareSerial SerialModem(2); // UART2
void setup() {
// Serial0 (USB) - default pins
Serial.begin(115200);
// UART1 - custom pins (TX=17, RX=16)
SerialGPS.begin(9600, SERIAL_8N1, 16, 17);
// UART2 - custom pins (TX=25, RX=26)
SerialModem.begin(115200, SERIAL_8N1, 26, 25);
}
void loop() {
// Read from GPS on UART1
if (SerialGPS.available()) {
String gpsData = SerialGPS.readStringUntil('\n');
Serial.println("GPS: " + gpsData);
}
// Read from modem on UART2
if (SerialModem.available()) {
String modemResponse = SerialModem.readStringUntil('\n');
Serial.println("Modem: " + modemResponse);
}
}
STM32 HAL UART
#include "stm32f4xx_hal.h"
UART_HandleTypeDef huart2;
void UART_Init(void) {
huart2.Instance = USART2;
huart2.Init.BaudRate = 115200;
huart2.Init.WordLength = UART_WORDLENGTH_8B;
huart2.Init.StopBits = UART_STOPBITS_1;
huart2.Init.Parity = UART_PARITY_NONE;
huart2.Init.Mode = UART_MODE_TX_RX;
huart2.Init.HwFlowCtl = UART_HWCONTROL_NONE;
huart2.Init.OverSampling = UART_OVERSAMPLING_16;
HAL_UART_Init(&huart2);
}
// Blocking transmission
void UART_SendString(char *str) {
HAL_UART_Transmit(&huart2, (uint8_t*)str, strlen(str), 100);
}
// Blocking reception
void UART_ReceiveData(uint8_t *buffer, uint16_t size) {
HAL_UART_Receive(&huart2, buffer, size, 1000);
}
// Interrupt-based reception
void UART_ReceiveIT(uint8_t *buffer, uint16_t size) {
HAL_UART_Receive_IT(&huart2, buffer, size);
}
// Callback when reception complete
void HAL_UART_RxCpltCallback(UART_HandleTypeDef *huart) {
if (huart->Instance == USART2) {
// Process received data
// Re-enable reception
HAL_UART_Receive_IT(&huart2, rxBuffer, RX_BUFFER_SIZE);
}
}
// DMA-based high-speed transfer
void UART_Transmit_DMA(uint8_t *data, uint16_t size) {
HAL_UART_Transmit_DMA(&huart2, data, size);
}
Bare-Metal AVR (Arduino Uno)
#include <avr/io.h>
#define BAUD 9600
#define UBRR_VALUE ((F_CPU / 16 / BAUD) - 1)
void UART_Init(void) {
// Set baud rate
UBRR0H = (UBRR_VALUE >> 8);
UBRR0L = UBRR_VALUE;
// Enable transmitter and receiver
UCSR0B = (1 << TXEN0) | (1 << RXEN0);
// Set frame format: 8 data bits, 1 stop bit, no parity
UCSR0C = (1 << UCSZ01) | (1 << UCSZ00);
}
void UART_Transmit(uint8_t data) {
// Wait for empty transmit buffer
while (!(UCSR0A & (1 << UDRE0)));
// Put data into buffer, sends the data
UDR0 = data;
}
uint8_t UART_Receive(void) {
// Wait for data to be received
while (!(UCSR0A & (1 << RXC0)));
// Get and return received data from buffer
return UDR0;
}
void UART_Print(const char *str) {
while (*str) {
UART_Transmit(*str++);
}
}
Common Use Cases
1. Debugging and Logging
// Real-time debugging output
Serial.print("Sensor value: ");
Serial.println(sensorValue);
Serial.print("Free RAM: ");
Serial.println(freeRam());
2. GPS Module Communication
// Reading NMEA sentences from GPS
if (SerialGPS.available()) {
String nmea = SerialGPS.readStringUntil('\n');
if (nmea.startsWith("$GPGGA")) {
parseGPS(nmea);
}
}
3. Wireless Module (Bluetooth, WiFi)
// AT command interface
SerialBT.println("AT+NAME=MyDevice");
delay(100);
String response = SerialBT.readString();
4. Sensor Communication
// CO2 sensor command
Serial1.write(cmd, 9);
delay(100);
if (Serial1.available() >= 9) {
Serial1.readBytes(response, 9);
}
5. PC Communication
// Command protocol with PC
void loop() {
if (Serial.available()) {
char cmd = Serial.read();
switch(cmd) {
case 'L': digitalWrite(LED, HIGH); break;
case 'l': digitalWrite(LED, LOW); break;
case 'T': Serial.println(readTemp()); break;
}
}
}
UART vs Other Protocols
| Feature | UART | I2C | SPI |
|---|---|---|---|
| Wires | 2 (+ GND) | 2 | 4+ |
| Clock | Asynchronous | Synchronous | Synchronous |
| Devices | 2 (point-to-point) | Many (multi-master) | 1 master, many slaves |
| Speed | Up to ~5 Mbps | Up to 3.4 Mbps | Up to 50+ MHz |
| Distance | Long (meters) | Short (< 1m) | Short (< 1m) |
| Complexity | Simple | Medium | Simple |
| Error Detection | Parity bit | ACK/NACK | None |
Best Practices
1. Proper Baud Rate Calculation
// Ensure both devices use exact same baud rate
// Check oscillator tolerance - should be < 2%
// For custom baud rates, verify with formula:
// UBRR = (F_CPU / (16 * BAUD)) - 1
2. Buffer Management
// Check available space before reading
if (Serial.available() > 0) {
int bytesToRead = Serial.available();
for (int i = 0; i < bytesToRead; i++) {
rxBuffer[i] = Serial.read();
}
}
// Or use built-in methods
Serial.readBytes(rxBuffer, expectedSize);
3. Timeout Handling
// Set appropriate timeout
Serial.setTimeout(500); // 500ms
// Caveat: parseInt() returns 0 on timeout, which is ambiguous when 0
// is also a valid value - design the protocol so 0 is never a legal
// payload, or use explicit framing instead
int value = Serial.parseInt();
if (value == 0) {
  Serial.println("Error: Timeout (or literal 0 received)");
}
4. Flow Control (Hardware)
RTS (Request To Send) and CTS (Clear To Send) are extra handshake lines used for high-speed links, or when the receiver might not be able to keep up with the sender.
5. Protocol Design
// Add framing for reliable communication
// Example: <START>DATA<END>
void sendPacket(uint8_t *data, uint8_t len) {
Serial.write(0x02); // STX (Start of Text)
for (int i = 0; i < len; i++) {
Serial.write(data[i]);
}
uint8_t checksum = calculateChecksum(data, len);
Serial.write(checksum);
Serial.write(0x03); // ETX (End of Text)
}
Common Issues and Debugging
Problem: Garbage Characters
Causes:
- Baud rate mismatch between devices
- Wrong oscillator frequency
- Noisy power supply
Solutions:
// Try common baud rates systematically
Serial.begin(9600); // Try this
Serial.begin(115200); // Then this
// Check your board's crystal frequency matches F_CPU
Problem: Missing Characters
Causes:
- Buffer overflow (data arriving faster than processing)
- Insufficient interrupt priority
Solutions:
// Increase the serial buffer size (defined in the core's HardwareSerial.h)
#define SERIAL_RX_BUFFER_SIZE 256
// Use hardware flow control
// Process data promptly in loop()
Problem: First Character Lost
Causes:
- Receiver not initialized before transmitter sends
- Start bit detection issue
Solutions:
// Add startup delay
void setup() {
Serial.begin(9600);
delay(100); // Wait for UART to stabilize
}
// Send dummy byte first
Serial.write(0x00);
delay(10);
Voltage Levels
TTL UART (3.3V or 5V)
- Logic HIGH: 2.4V - 5V
- Logic LOW: 0V - 0.8V
- Most microcontrollers use this
RS-232 UART (Legacy)
- Logic 1 (Mark): -3V to -15V
- Logic 0 (Space): +3V to +15V
- Requires level shifter (MAX232, MAX3232)
- Longer cable runs possible
// Using MAX232 level shifter
// MCU TX -> MAX232 T1IN -> MAX232 T1OUT -> PC RX
// MCU RX <- MAX232 R1OUT <- MAX232 R1IN <- PC TX
ELI10 (Explain Like I'm 10)
Imagine you and your friend are in different rooms and want to talk using two cans connected by a string:
- TX (Transmit) is your mouth speaking into the can
- RX (Receive) is your ear listening from the can
- Baud Rate is how fast you talk - if one person talks super fast and the other listens slowly, you won't understand each other!
- Start Bit is like saying "Hey, listen!" before each word
- Stop Bit is like a pause after each word
The cool thing? Both of you can talk and listen at the same time because you have two strings (wires)!
The tricky part? You MUST both agree to talk at the same speed (baud rate) before starting, because there's no way to say "slow down!" once you've begun.
Further Resources
- UART Wikipedia
- SparkFun Serial Communication Tutorial
- Arduino Serial Reference
- AN4666: STM32 UART Concepts
- Baud Rate Calculator
USB Protocol
Comprehensive guide to USB protocol, device classes, and embedded implementation.
Table of Contents
- Introduction
- USB Basics
- USB Protocol
- Device Classes
- Descriptors
- Arduino USB
- STM32 USB
- USB CDC (Virtual Serial)
Introduction
USB (Universal Serial Bus) is a standard for connecting devices to a host computer. It provides both power and data communication in a single cable.
USB Versions
| Version | Name | Speed | Release | Connector |
|---|---|---|---|---|
| USB 1.0 | Low Speed | 1.5 Mbps | 1996 | Type A/B |
| USB 1.1 | Full Speed | 12 Mbps | 1998 | Type A/B |
| USB 2.0 | High Speed | 480 Mbps | 2000 | Type A/B, Mini, Micro |
| USB 3.0 | SuperSpeed | 5 Gbps | 2008 | Type A/B, Micro B SS |
| USB 3.1 | SuperSpeed+ | 10 Gbps | 2013 | Type C |
| USB 3.2 | - | 20 Gbps | 2017 | Type C |
| USB 4.0 | - | 40 Gbps | 2019 | Type C |
USB Connectors
USB Type A (Host):
┌─────────────────┐
│ ┌─┐ ┌─┐ ┌─┐ ┌─┐ │
│ │1│ │2│ │3│ │4│ │
│ └─┘ └─┘ └─┘ └─┘ │
└─────────────────┘
1: VBUS (+5V)
2: D- (Data -)
3: D+ (Data +)
4: GND
USB Type B (Device):
┌───┐
┌─┘ └─┐
│ 1 2 │
│ 3 4 │
└───────┘
USB Micro B (Common on embedded):
┌─────────┐
│1 2 3 4 5│
└─────────┘
1: VBUS (+5V)
2: D-
3: D+
4: ID (OTG)
5: GND
USB Type C (Modern):
┌───────────┐
│A1 A2...A12│
│B1 B2...B12│
└───────────┘
Reversible, 24 pins
USB Topology
Host (PC/Hub)
│
├─── Device 1 (Address 1)
│
├─── Hub (Address 2)
│ │
│ ├─── Device 2 (Address 3)
│ └─── Device 3 (Address 4)
│
└─── Device 4 (Address 5)
Maximum:
- 127 devices per host
- 5 meter cable length per segment
- 7 tiers (including the root hub)
USB Basics
Signal Levels
- Low Speed (1.5 Mbps): device pull-up on D-
- Full Speed (12 Mbps): device pull-up on D+
- High Speed (480 Mbps): enumerates at Full Speed, then switches to low-voltage differential signaling after a chirp handshake
Power
USB 2.0: 5V, 500 mA max
USB 3.0: 5V, 900 mA max
USB-C PD: 5V, 9V, 15V, 20V up to 100W
Data Encoding
USB uses NRZI (Non-Return-to-Zero Inverted) encoding with bit stuffing:
- 0 bit: transition
- 1 bit: no transition
- Bit stuffing: after six consecutive 1s, insert a 0
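The encoding rules above can be sketched in plain C. This encoder assumes an idle line level of 1 and models a stuffed 0 as a forced extra transition after six consecutive 1s (a simplified sketch, not a full USB PHY):

```c
#include <stdint.h>
#include <stddef.h>

// Encode a logical bit stream as USB-style NRZI with bit stuffing.
// in[]: logical bits (0/1); out[]: resulting line levels.
// Returns the number of output bits (may exceed n due to stuffing).
size_t nrzi_encode(const uint8_t *in, size_t n, uint8_t *out) {
    uint8_t level = 1;     // assumed idle level
    int ones = 0;
    size_t k = 0;
    for (size_t i = 0; i < n; i++) {
        if (in[i] == 0) { level ^= 1; ones = 0; }  // 0 bit: transition
        else            { ones++; }                 // 1 bit: no transition
        out[k++] = level;
        if (ones == 6) {                // stuff a 0 after six 1s
            level ^= 1;                 // the stuffed 0 forces a transition
            out[k++] = level;
            ones = 0;
        }
    }
    return k;
}
```

Seven 1s in a row produce eight output bits: the stuffed 0 keeps the receiver's clock recovery alive.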
Packet Types
Token Packets:
- SETUP: Initialize control transfer
- IN: Request data from device
- OUT: Send data to device
- SOF: Start of Frame (every 1 ms)
Data Packets:
- DATA0: Even data packet
- DATA1: Odd data packet
- DATA2: High-speed data
- MDATA: Multi-data
Handshake Packets:
- ACK: Acknowledge success
- NAK: Not ready
- STALL: Endpoint halted
Special Packets:
- PRE: Preamble for low-speed
- ERR: Error detected
- SPLIT: High-speed split transaction
USB Protocol
Enumeration Process
1. Device Plugged In
│
├─ USB Reset (SE0 for 10ms)
│
2. Host Assigns Address 0 (default)
│
├─ Get Device Descriptor
│ Response: VID, PID, max packet size
│
3. Host Assigns Unique Address (1-127)
│
├─ Set Address
│
4. Host Requests Configuration
│
├─ Get Configuration Descriptor
│ Response: Interfaces, endpoints, class info
│
├─ Get String Descriptors (optional)
│ Response: Manufacturer, product, serial
│
5. Host Configures Device
│
├─ Set Configuration
│
6. Device Ready for Use
Transfer Types
| Transfer Type | Speed | Error Correction | Use Case |
|---|---|---|---|
| Control | Any | Yes | Device enumeration, configuration |
| Bulk | Full/High | Yes | Large data transfers (storage, printers) |
| Interrupt | Any | Yes | Small, periodic data (HID, mice) |
| Isochronous | Full/High | No | Real-time audio/video |
Control Transfer Structure
Setup Stage:
Host → Device: SETUP token + DATA0 packet
Data Stage (optional):
IN: Device → Host: DATA packets
OUT: Host → Device: DATA packets
Status Stage (direction opposite to the data stage):
After an IN data stage: Host → Device: zero-length DATA1 (OUT), device ACKs
After an OUT data stage: Device → Host: zero-length DATA1 (IN), host ACKs
Standard Requests
// bmRequestType: Direction | Type | Recipient
#define USB_DIR_OUT 0x00
#define USB_DIR_IN 0x80
#define USB_TYPE_STANDARD 0x00
#define USB_TYPE_CLASS 0x20
#define USB_TYPE_VENDOR 0x40
#define USB_RECIP_DEVICE 0x00
#define USB_RECIP_INTERFACE 0x01
#define USB_RECIP_ENDPOINT 0x02
// bRequest codes
#define USB_REQ_GET_STATUS 0
#define USB_REQ_CLEAR_FEATURE 1
#define USB_REQ_SET_FEATURE 3
#define USB_REQ_SET_ADDRESS 5
#define USB_REQ_GET_DESCRIPTOR 6
#define USB_REQ_SET_DESCRIPTOR 7
#define USB_REQ_GET_CONFIGURATION 8
#define USB_REQ_SET_CONFIGURATION 9
#define USB_REQ_GET_INTERFACE 10
#define USB_REQ_SET_INTERFACE 11
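The codes above combine into the 8-byte SETUP packet that starts every control transfer. A sketch of building a GET_DESCRIPTOR request (field layout follows the standard; the helper name is our own):

```c
#include <stdint.h>

#define USB_DIR_IN             0x80
#define USB_TYPE_STANDARD      0x00
#define USB_RECIP_DEVICE       0x00
#define USB_REQ_GET_DESCRIPTOR 6

// Fill the 8-byte SETUP packet for GET_DESCRIPTOR.
// All multi-byte fields are little-endian on the wire.
void build_get_descriptor(uint8_t setup[8], uint8_t descType,
                          uint8_t descIndex, uint16_t wLength) {
    setup[0] = USB_DIR_IN | USB_TYPE_STANDARD | USB_RECIP_DEVICE; // bmRequestType
    setup[1] = USB_REQ_GET_DESCRIPTOR;       // bRequest
    setup[2] = descIndex;                    // wValue low byte: descriptor index
    setup[3] = descType;                     // wValue high byte: descriptor type
    setup[4] = 0; setup[5] = 0;              // wIndex = 0
    setup[6] = (uint8_t)(wLength & 0xFF);    // wLength low byte
    setup[7] = (uint8_t)(wLength >> 8);      // wLength high byte
}
```

For the device descriptor the host requests type 0x01, index 0, length 18, which is exactly what appears during enumeration step 2 above.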
Device Classes
USB Class Codes
| Class | Code | Description | Examples |
|---|---|---|---|
| CDC | 0x02 | Communications Device | Virtual COM port, modems |
| HID | 0x03 | Human Interface Device | Keyboards, mice, game controllers |
| Mass Storage | 0x08 | Storage Device | USB flash drives, external HDDs |
| Hub | 0x09 | USB Hub | - |
| Audio | 0x01 | Audio Device | Speakers, microphones |
| Video | 0x0E | Video Device | Webcams |
| Printer | 0x07 | Printer | - |
| Vendor Specific | 0xFF | Custom | - |
HID (Human Interface Device)
// HID Descriptor
struct HID_Descriptor {
uint8_t bLength; // Size of descriptor
uint8_t bDescriptorType; // HID descriptor type (0x21)
uint16_t bcdHID; // HID specification release
uint8_t bCountryCode; // Country code
uint8_t bNumDescriptors; // Number of class descriptors
uint8_t bDescriptorType2; // Report descriptor type (0x22)
uint16_t wDescriptorLength; // Length of report descriptor
};
// HID Report Descriptor (Mouse example)
const uint8_t mouse_report_descriptor[] = {
0x05, 0x01, // Usage Page (Generic Desktop)
0x09, 0x02, // Usage (Mouse)
0xA1, 0x01, // Collection (Application)
0x09, 0x01, // Usage (Pointer)
0xA1, 0x00, // Collection (Physical)
0x05, 0x09, // Usage Page (Buttons)
0x19, 0x01, // Usage Minimum (Button 1)
0x29, 0x03, // Usage Maximum (Button 3)
0x15, 0x00, // Logical Minimum (0)
0x25, 0x01, // Logical Maximum (1)
0x95, 0x03, // Report Count (3)
0x75, 0x01, // Report Size (1)
0x81, 0x02, // Input (Data, Variable, Absolute)
0x95, 0x01, // Report Count (1)
0x75, 0x05, // Report Size (5)
0x81, 0x01, // Input (Constant) - Padding
0x05, 0x01, // Usage Page (Generic Desktop)
0x09, 0x30, // Usage (X)
0x09, 0x31, // Usage (Y)
0x15, 0x81, // Logical Minimum (-127)
0x25, 0x7F, // Logical Maximum (127)
0x75, 0x08, // Report Size (8)
0x95, 0x02, // Report Count (2)
0x81, 0x06, // Input (Data, Variable, Relative)
0xC0, // End Collection
0xC0 // End Collection
};
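The report descriptor above defines a 3-byte input report: one buttons byte (3 bits used, 5 padding) followed by relative X and Y as signed bytes. A small packing helper under that assumption:

```c
#include <stdint.h>

// 3-byte input report matching the mouse report descriptor above:
// byte 0: buttons (bits 0-2) + 5 padding bits, bytes 1-2: relative X/Y
typedef struct {
    uint8_t buttons;
    int8_t  x;
    int8_t  y;
} MouseReport;

void pack_mouse_report(uint8_t out[3], uint8_t buttons, int8_t dx, int8_t dy) {
    out[0] = buttons & 0x07;   // only buttons 1-3; padding bits stay zero
    out[1] = (uint8_t)dx;      // two's-complement relative movement
    out[2] = (uint8_t)dy;
}
```

Negative movement simply becomes the two's-complement byte, which the host interprets via the -127..127 logical range declared in the descriptor.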
CDC (Communication Device Class)
Used for virtual serial ports (USB to UART).
// CDC ACM (Abstract Control Model) Interface
// CDC Header Functional Descriptor
struct CDC_Header_Descriptor {
uint8_t bLength;
uint8_t bDescriptorType;
uint8_t bDescriptorSubtype; // Header (0x00)
uint16_t bcdCDC;
};
// CDC Call Management Descriptor
struct CDC_CallManagement_Descriptor {
uint8_t bLength;
uint8_t bDescriptorType;
uint8_t bDescriptorSubtype; // Call Management (0x01)
uint8_t bmCapabilities;
uint8_t bDataInterface;
};
// CDC Line Coding (115200 8N1 example)
struct CDC_LineCoding {
uint32_t dwDTERate; // Baud rate: 115200
uint8_t bCharFormat; // Stop bits: 1
uint8_t bParityType; // Parity: None (0)
uint8_t bDataBits; // Data bits: 8
};
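On the wire, SET_LINE_CODING carries these fields as 7 bytes with a little-endian baud rate. A sketch of the serialization (the helper name is our own):

```c
#include <stdint.h>

// Serialize line coding into the 7-byte CDC wire format.
void cdc_encode_line_coding(uint8_t out[7], uint32_t baud,
                            uint8_t stopBits, uint8_t parity,
                            uint8_t dataBits) {
    out[0] = (uint8_t)(baud);        // dwDTERate, little-endian
    out[1] = (uint8_t)(baud >> 8);
    out[2] = (uint8_t)(baud >> 16);
    out[3] = (uint8_t)(baud >> 24);
    out[4] = stopBits;               // bCharFormat: 0 = 1 stop bit
    out[5] = parity;                 // bParityType: 0 = none
    out[6] = dataBits;               // bDataBits: 8
}
```

For 115200 8N1 the first four bytes come out as 00 C2 01 00, which is what you see in a USB capture of the request.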
Descriptors
Device Descriptor
struct USB_Device_Descriptor {
uint8_t bLength; // Size: 18 bytes
uint8_t bDescriptorType; // DEVICE (0x01)
uint16_t bcdUSB; // USB version (0x0200 for USB 2.0)
uint8_t bDeviceClass; // Class code
uint8_t bDeviceSubClass; // Subclass code
uint8_t bDeviceProtocol; // Protocol code
uint8_t bMaxPacketSize0; // Max packet size for EP0
uint16_t idVendor; // Vendor ID (VID)
uint16_t idProduct; // Product ID (PID)
uint16_t bcdDevice; // Device release number
uint8_t iManufacturer; // Manufacturer string index
uint8_t iProduct; // Product string index
uint8_t iSerialNumber; // Serial number string index
uint8_t bNumConfigurations; // Number of configurations
};
// Example
const uint8_t device_descriptor[] = {
18, // bLength
0x01, // bDescriptorType (DEVICE)
0x00, 0x02, // bcdUSB (USB 2.0)
0x00, // bDeviceClass (defined in interface)
0x00, // bDeviceSubClass
0x00, // bDeviceProtocol
64, // bMaxPacketSize0
0x83, 0x04, // idVendor (0x0483 - STMicroelectronics)
0x40, 0x57, // idProduct (0x5740)
0x00, 0x02, // bcdDevice (2.0)
1, // iManufacturer
2, // iProduct
3, // iSerialNumber
1 // bNumConfigurations
};
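Multi-byte descriptor fields are little-endian, so VID, PID, and bcdUSB can be read back out of the raw bytes. A small sketch using the example descriptor above:

```c
#include <stdint.h>

// The example device descriptor from above (18 bytes)
static const uint8_t device_descriptor[] = {
    18, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 64,
    0x83, 0x04, 0x40, 0x57, 0x00, 0x02, 1, 2, 3, 1
};

// Read a 16-bit little-endian field at a byte offset
uint16_t desc_u16(const uint8_t *d, int offset) {
    return (uint16_t)(d[offset] | (d[offset + 1] << 8));
}
```

Offsets follow the struct layout: bcdUSB at 2, idVendor at 8, idProduct at 10, recovering 0x0483/0x5740 from the byte pairs.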
Configuration Descriptor
struct USB_Configuration_Descriptor {
uint8_t bLength; // Size: 9 bytes
uint8_t bDescriptorType; // CONFIGURATION (0x02)
uint16_t wTotalLength; // Total length of data
uint8_t bNumInterfaces; // Number of interfaces
uint8_t bConfigurationValue; // Configuration index
uint8_t iConfiguration; // Configuration string index
uint8_t bmAttributes; // Attributes (self/bus powered)
uint8_t bMaxPower; // Max power in 2mA units
};
Interface Descriptor
struct USB_Interface_Descriptor {
uint8_t bLength; // Size: 9 bytes
uint8_t bDescriptorType; // INTERFACE (0x04)
uint8_t bInterfaceNumber; // Interface index
uint8_t bAlternateSetting; // Alternate setting
uint8_t bNumEndpoints; // Number of endpoints
uint8_t bInterfaceClass; // Class code
uint8_t bInterfaceSubClass; // Subclass code
uint8_t bInterfaceProtocol; // Protocol code
uint8_t iInterface; // Interface string index
};
Endpoint Descriptor
struct USB_Endpoint_Descriptor {
uint8_t bLength; // Size: 7 bytes
uint8_t bDescriptorType; // ENDPOINT (0x05)
uint8_t bEndpointAddress; // Address (bit 7: direction)
uint8_t bmAttributes; // Transfer type
uint16_t wMaxPacketSize; // Max packet size
uint8_t bInterval; // Polling interval (ms)
};
// Endpoint address format:
// Bit 7: Direction (0 = OUT, 1 = IN)
// Bits 3-0: Endpoint number (0-15)
#define USB_EP_IN(n) (0x80 | (n))
#define USB_EP_OUT(n) (n)
// Transfer types
#define USB_EP_TYPE_CONTROL 0x00
#define USB_EP_TYPE_ISOCHRONOUS 0x01
#define USB_EP_TYPE_BULK 0x02
#define USB_EP_TYPE_INTERRUPT 0x03
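A quick sanity check of the address macros above, with small decode helpers going the other way (the helper names are our own):

```c
#include <stdint.h>

#define USB_EP_IN(n)  (0x80 | (n))
#define USB_EP_OUT(n) (n)

// Decode an endpoint address back into direction and number
int ep_is_in(uint8_t addr)  { return (addr & 0x80) != 0; }
int ep_number(uint8_t addr) { return addr & 0x0F; }
```

So endpoint 1 IN is 0x81 and endpoint 2 OUT is 0x02, matching the bit layout in the comments above.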
String Descriptor
struct USB_String_Descriptor {
uint8_t bLength;
uint8_t bDescriptorType; // STRING (0x03)
uint16_t wString[]; // Unicode string
};
// String 0 (Language ID)
const uint8_t string0[] = {
4, // bLength
0x03, // bDescriptorType
0x09, 0x04 // wLANGID[0]: 0x0409 (English - US)
};
// String 1 (Manufacturer)
const uint8_t string1[] = {
26, // bLength (2 + 2 × 12 characters; no terminator)
0x03, // bDescriptorType
'M',0, 'a',0, 'n',0, 'u',0, 'f',0, 'a',0, 'c',0,
't',0, 'u',0, 'r',0, 'e',0, 'r',0
};
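String descriptors hold UTF-16LE text with no terminator, so they can be built mechanically from ASCII. A sketch (the helper name is our own):

```c
#include <stdint.h>
#include <stddef.h>
#include <string.h>

// Build a UTF-16LE string descriptor from an ASCII string.
// buf must hold 2 + 2*strlen(s) bytes; returns bLength.
uint8_t make_string_descriptor(uint8_t *buf, const char *s) {
    size_t n = strlen(s);
    buf[0] = (uint8_t)(2 + 2 * n);   // bLength
    buf[1] = 0x03;                   // bDescriptorType (STRING)
    for (size_t i = 0; i < n; i++) {
        buf[2 + 2 * i] = (uint8_t)s[i];  // low byte: ASCII code point
        buf[3 + 2 * i] = 0x00;           // high byte: zero for ASCII
    }
    return buf[0];
}
```

This only covers ASCII; real UTF-16 text with non-Latin characters needs proper encoding.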
Arduino USB
Arduino Leonardo/Micro (ATmega32u4)
The ATmega32u4 has native USB support.
USB Mouse
#include <Mouse.h>
void setup() {
Mouse.begin();
}
void loop() {
// Move mouse in a square
Mouse.move(10, 0); // Right
delay(500);
Mouse.move(0, 10); // Down
delay(500);
Mouse.move(-10, 0); // Left
delay(500);
Mouse.move(0, -10); // Up
delay(500);
}
USB Keyboard
#include <Keyboard.h>
const int BUTTON_PIN = 2;
void setup() {
pinMode(BUTTON_PIN, INPUT_PULLUP);
Keyboard.begin();
}
void loop() {
if (digitalRead(BUTTON_PIN) == LOW) {
Keyboard.print("Hello, World!");
delay(500);
}
}
USB HID Custom
#include <HID.h>
// Custom HID report descriptor
static const uint8_t _hidReportDescriptor[] PROGMEM = {
0x06, 0x00, 0xFF, // Usage Page (Vendor Defined)
0x09, 0x01, // Usage (Vendor Usage 1)
0xA1, 0x01, // Collection (Application)
0x15, 0x00, // Logical Minimum (0)
0x26, 0xFF, 0x00, // Logical Maximum (255)
0x75, 0x08, // Report Size (8 bits)
0x95, 0x40, // Report Count (64)
0x09, 0x01, // Usage (Vendor Usage 1)
0x81, 0x02, // Input (Data, Variable, Absolute)
0x09, 0x01, // Usage (Vendor Usage 1)
0x91, 0x02, // Output (Data, Variable, Absolute)
0xC0 // End Collection
};
void setup() {
static HIDSubDescriptor node(_hidReportDescriptor, sizeof(_hidReportDescriptor));
HID().AppendDescriptor(&node);
}
void loop() {
uint8_t data[64] = {1, 2, 3, 4};
HID().SendReport(1, data, 64);
delay(100);
}
STM32 USB
USB CDC Virtual COM Port (CubeMX)
/* Generated by CubeMX with USB Device middleware */
#include "usbd_cdc_if.h"
int main(void) {
HAL_Init();
SystemClock_Config();
MX_USB_DEVICE_Init();
uint8_t buffer[64];
sprintf((char*)buffer, "Hello from STM32!\r\n");
while (1) {
CDC_Transmit_FS(buffer, strlen((char*)buffer));
HAL_Delay(1000);
}
}
/* In usbd_cdc_if.c */
static int8_t CDC_Receive_FS(uint8_t* Buf, uint32_t *Len) {
// Echo back received data
CDC_Transmit_FS(Buf, *Len);
return USBD_OK;
}
USB HID Keyboard
/* Configure USB Device as HID in CubeMX */
#include "usbd_hid.h"
extern USBD_HandleTypeDef hUsbDeviceFS;
// HID keyboard report
typedef struct {
uint8_t modifiers; // Ctrl, Shift, Alt, GUI
uint8_t reserved;
uint8_t keys[6]; // Up to 6 simultaneous keys
} KeyboardReport;
void send_key(uint8_t key) {
KeyboardReport report = {0};
// Press key
report.keys[0] = key;
USBD_HID_SendReport(&hUsbDeviceFS, (uint8_t*)&report, sizeof(report));
HAL_Delay(10);
// Release key
memset(&report, 0, sizeof(report));
USBD_HID_SendReport(&hUsbDeviceFS, (uint8_t*)&report, sizeof(report));
HAL_Delay(10);
}
int main(void) {
HAL_Init();
SystemClock_Config();
MX_USB_DEVICE_Init();
HAL_Delay(1000); // Wait for enumeration
while (1) {
// Send 'A' key
send_key(0x04); // HID usage code for 'A'
HAL_Delay(1000);
}
}
USB Mass Storage
/* Configure USB Device as MSC in CubeMX */
#include "usbd_storage_if.h"
// Implement SCSI commands
int8_t STORAGE_Read_FS(uint8_t lun, uint8_t *buf, uint32_t blk_addr, uint16_t blk_len) {
// Read from SD card or internal flash
for (uint16_t i = 0; i < blk_len; i++) {
// Read block at (blk_addr + i) to (buf + i * BLOCK_SIZE)
}
return USBD_OK;
}
int8_t STORAGE_Write_FS(uint8_t lun, uint8_t *buf, uint32_t blk_addr, uint16_t blk_len) {
// Write to SD card or internal flash
for (uint16_t i = 0; i < blk_len; i++) {
// Write block at (blk_addr + i) from (buf + i * BLOCK_SIZE)
}
return USBD_OK;
}
USB CDC (Virtual Serial)
PC Side (Python)
import serial
import time
# Open serial port
ser = serial.Serial('COM3', 115200, timeout=1) # Windows
# ser = serial.Serial('/dev/ttyACM0', 115200, timeout=1) # Linux
# Write data
ser.write(b'Hello, Device!\n')
# Read data
while True:
if ser.in_waiting > 0:
data = ser.readline()
print(f"Received: {data.decode()}")
time.sleep(0.1)
ser.close()
PC Side (C++)
// Linux example
#include <fcntl.h>
#include <unistd.h>
#include <termios.h>
int main() {
int fd = open("/dev/ttyACM0", O_RDWR);
// Configure serial port
struct termios tty;
tcgetattr(fd, &tty);
cfsetospeed(&tty, B115200);
cfsetispeed(&tty, B115200);
tty.c_cflag |= (CLOCAL | CREAD);
tty.c_cflag &= ~PARENB;
tty.c_cflag &= ~CSTOPB;
tty.c_cflag &= ~CSIZE;
tty.c_cflag |= CS8;
tcsetattr(fd, TCSANOW, &tty);
// Write
char msg[] = "Hello, Device!\n";
write(fd, msg, sizeof(msg));
// Read
char buffer[256];
int n = read(fd, buffer, sizeof(buffer));
buffer[n] = '\0';
printf("Received: %s\n", buffer);
close(fd);
return 0;
}
Best Practices
- VID/PID: Use unique Vendor ID and Product ID (or get your own)
- Descriptors: Ensure correct descriptor chain
- Enumeration: Handle USB reset and enumeration properly
- Power: Declare correct power consumption
- String Descriptors: Provide manufacturer, product, serial number
- Error Handling: Handle NAK, STALL conditions
- Buffer Management: Use DMA for better performance
- Compliance: Test with USB-IF tools for certification
Debugging Tools
Linux
# List USB devices
lsusb
# Detailed info
lsusb -v
# Monitor USB traffic (0u = all buses)
sudo cat /sys/kernel/debug/usb/usbmon/0u
# Install usbutils
sudo apt install usbutils
Windows
- USBView: Microsoft USB device viewer
- USBDeview: NirSoft utility
- Wireshark: With USB capture support
Hardware
- USB Protocol Analyzer: Beagle USB, Total Phase
- Logic Analyzer: Can decode USB signals
Troubleshooting
Common Issues
Device Not Recognized:
- Check USB cable (data lines)
- Verify correct descriptors
- Check VID/PID not conflicting
- Ensure proper enumeration handling
Intermittent Disconnects:
- Power supply insufficient
- Check USB cable quality
- Verify proper suspend/resume handling
Data Corruption:
- Check buffer sizes
- Verify DMA configuration
- Ensure proper synchronization
Slow Transfer Speed:
- Use bulk transfers for large data
- Enable DMA
- Optimize buffer sizes
- Check USB 2.0 High Speed mode
Resources
- USB Specification: USB.org
- USB Made Simple: https://www.usbmadesimple.co.uk/
- STM32 USB Training: ST's USB training materials
- Jan Axelson's USB: Classic USB development book
- Linux USB: https://www.kernel.org/doc/html/latest/driver-api/usb/
CAN (Controller Area Network)
Controller Area Network (CAN) is a robust vehicle bus standard designed to allow microcontrollers and devices to communicate with each other without a host computer. It is widely used in automotive and industrial applications due to its reliability and efficiency.
Key Concepts
- Frames: CAN communication is based on frames, which are structured packets of data. Each frame contains an identifier, control bits, data, and error-checking information.
- Identifiers: Each frame has a unique identifier that determines the priority of the message. Lower identifier values have higher priority on the bus.
- Bitwise Arbitration: CAN uses a non-destructive bitwise arbitration method to control access to the bus. This ensures that the highest-priority message is transmitted without collision.
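Arbitration can be modeled as a wired-AND bus: a dominant bit (0) overrides a recessive bit (1), and a node that sends recessive but reads dominant drops out. A simplified simulation under those assumptions (11-bit IDs, up to 32 nodes, no error handling):

```c
#include <stdint.h>

// Simulate CAN bitwise arbitration over 11-bit identifiers, MSB first.
// Returns the ID that wins the bus (the lowest numeric ID). n <= 32.
uint16_t can_arbitrate(const uint16_t *ids, int n) {
    int alive[32];
    for (int i = 0; i < n; i++) alive[i] = 1;
    for (int bit = 10; bit >= 0; bit--) {
        int bus = 1;  // recessive unless some node drives dominant (0)
        for (int i = 0; i < n; i++)
            if (alive[i] && !((ids[i] >> bit) & 1)) bus = 0;
        for (int i = 0; i < n; i++)
            if (alive[i] && ((ids[i] >> bit) & 1) && bus == 0)
                alive[i] = 0;  // sent recessive, read dominant: back off
    }
    for (int i = 0; i < n; i++)
        if (alive[i]) return ids[i];
    return 0;
}
```

Because a 0 always wins each bit, the lowest identifier survives every round, which is exactly why lower IDs have higher priority.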
Common Standards
- CAN 2.0A: This standard defines 11-bit identifiers for frames.
- CAN 2.0B: This standard extends the identifier length to 29 bits, allowing for more unique message identifiers.
- CAN FD (Flexible Data-rate): This standard allows for higher data rates and larger data payloads compared to traditional CAN.
Applications
CAN is used in various applications, including:
- Automotive: Enabling communication between different electronic control units (ECUs) in vehicles, such as engine control, transmission, and braking systems.
- Industrial Automation: Facilitating communication between sensors, actuators, and controllers in manufacturing and process control systems.
- Medical Equipment: Ensuring reliable data exchange between different components of medical devices.
Conclusion
CAN is a critical communication protocol in automotive and industrial systems, providing reliable and efficient data exchange. Understanding CAN's principles and standards is essential for engineers working in these fields.
SDIO
Overview
SDIO (Secure Digital Input Output) is an extension of the SD (Secure Digital) card standard that allows for the integration of input/output devices into the SD card interface. This enables various peripherals, such as Wi-Fi, Bluetooth, GPS, and other sensors, to be connected to a host device through a standard SD card slot.
Key Features
- Versatility: SDIO supports a wide range of devices, making it suitable for various applications in mobile devices, embedded systems, and consumer electronics.
- Hot Swappable: SDIO devices can be inserted and removed while the host device is powered on, allowing for greater flexibility in device management.
- Standardized Interface: The SDIO interface is standardized, which simplifies the development process for manufacturers and developers.
Applications
SDIO is commonly used in:
- Wireless Communication: Many Wi-Fi and Bluetooth modules utilize SDIO to connect to host devices, enabling wireless connectivity.
- GPS Modules: GPS receivers can be integrated via SDIO, providing location services to mobile devices.
- Sensor Integration: Various sensors, such as accelerometers and gyroscopes, can be connected through SDIO for enhanced functionality in applications like gaming and navigation.
Signals
In the context of SDIO, signals refer to the electrical signals used for communication between the host device and the SDIO peripheral. These signals are essential for data transfer, command execution, and device management. The key signals in the SDIO interface include:
- CMD (Command Line): This signal is used to send commands from the host to the SDIO device. It is essential for initiating communication and controlling the operation of the device.
- CLK (Clock Line): The clock signal synchronizes the data transfer between the host and the SDIO device, keeping both ends in step during communication.
- DATA (Data Lines): These lines are used for data transfer between the host and the SDIO device. SDIO supports multiple data lines (typically 1, 4, or 8) to increase the data transfer rate.
- CD (Card Detect): This signal indicates whether an SDIO device is present in the slot. It allows the host to detect when a device is inserted or removed.
- WP (Write Protect): This signal indicates whether the SDIO device is write-protected. It prevents accidental data modification when the device is in a write-protect state.
Conclusion
SDIO is a powerful extension of the SD card standard that enhances the capabilities of mobile and embedded devices by allowing the integration of various peripherals. Its versatility and standardized interface make it a popular choice for developers looking to expand the functionality of their devices.
Ethernet
Ethernet is a widely used networking technology that enables devices to communicate over a local area network (LAN). It is a fundamental technology for connecting computers, printers, and other devices in homes and businesses.
Key Concepts
- Frames: Ethernet transmits data in packets called frames. Each frame contains source and destination MAC addresses, as well as the data being transmitted.
- MAC Address: A Media Access Control (MAC) address is a unique identifier assigned to network interfaces for communication on the physical network segment.
- Switching: Ethernet switches are devices that connect multiple devices on a LAN and use MAC addresses to forward frames to the correct destination.
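The frame layout described above can be sketched as a C struct, with a helper to print a MAC address in the usual colon-separated notation (the struct reflects Ethernet II framing; the names are our own):

```c
#include <stdint.h>
#include <stdio.h>

// Minimal Ethernet II header layout (14 bytes on the wire)
struct eth_header {
    uint8_t  dest[6];    // destination MAC address
    uint8_t  src[6];     // source MAC address
    uint16_t ethertype;  // e.g. 0x0800 = IPv4; big-endian on the wire
};

// Format a MAC address as colon-separated lowercase hex
void mac_to_string(const uint8_t mac[6], char out[18]) {
    snprintf(out, 18, "%02x:%02x:%02x:%02x:%02x:%02x",
             mac[0], mac[1], mac[2], mac[3], mac[4], mac[5]);
}
```

A switch's forwarding table is essentially a map from these 6-byte addresses to output ports.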
Common Standards
- IEEE 802.3: This is the standard that defines the physical and data link layers for Ethernet networks. It includes specifications for various speeds, such as 10 Mbps, 100 Mbps, 1 Gbps, and 10 Gbps.
- Full Duplex: Modern Ethernet supports full duplex communication, allowing devices to send and receive data simultaneously, which improves network efficiency.
- VLANs: Virtual Local Area Networks (VLANs) allow network administrators to segment a single physical network into multiple logical networks for improved security and performance.
Applications
Ethernet is used in various applications, including:
- Local Area Networking: Connecting computers and devices within a limited geographical area, such as an office or home.
- Data Centers: Providing high-speed connections between servers and storage devices.
- Industrial Automation: Enabling communication between machines and control systems in manufacturing environments.
Different Signals in Ethernet
Ethernet communication relies on various signals to transmit data over the network. These signals include:
- Carrier Sense: Ethernet devices use carrier sense to detect if the network medium is idle or busy before transmitting data. This helps prevent collisions on the network.
- Collision Detection: In half-duplex Ethernet, devices use collision detection to identify when two devices transmit data simultaneously, causing a collision. When a collision is detected, devices stop transmitting and wait for a random backoff period before attempting to retransmit.
- Preamble: Each Ethernet frame begins with a preamble, a sequence of alternating 1s and 0s, which allows devices to synchronize their clocks and prepare for the incoming data.
- Start Frame Delimiter (SFD): Following the preamble, the SFD is a specific pattern that indicates the start of the actual Ethernet frame.
- Clock Signals: Ethernet devices use clock signals to maintain synchronization between the transmitter and receiver, ensuring accurate data transmission.
- Link Pulse: In 10BASE-T Ethernet, link pulses are used to establish and maintain a connection between devices. These pulses are sent periodically to indicate that the link is active.
Understanding these signals is crucial for diagnosing and troubleshooting Ethernet network issues, as well as for designing and implementing reliable Ethernet communication systems.
Conclusion
Ethernet remains a cornerstone of modern networking, providing reliable and high-speed communication for a wide range of applications. Understanding Ethernet's principles and standards is essential for network engineers and IT professionals.
PWM (Pulse Width Modulation)
Overview
Pulse Width Modulation (PWM) is a technique for controlling power delivery to electrical devices by rapidly switching between ON and OFF states. By varying the ratio of ON time to OFF time (duty cycle), you can control the average power delivered without actually changing the voltage level. This makes PWM highly efficient and versatile for applications ranging from LED dimming to motor control.
Key Concepts
Duty Cycle
The duty cycle is the percentage of time the signal is HIGH during one complete cycle.
Duty Cycle (%) = (Ton / (Ton + Toff)) × 100
Where:
- Ton = Time the signal is HIGH
- Toff = Time the signal is LOW
- Period = Ton + Toff
Examples:
- 0% duty cycle: Always LOW (0V average)
- 25% duty cycle: HIGH for 1/4 of the period
- 50% duty cycle: HIGH for half the period
- 75% duty cycle: HIGH for 3/4 of the period
- 100% duty cycle: Always HIGH (full voltage)
Frequency
The frequency determines how many ON/OFF cycles occur per second, measured in Hertz (Hz).
Frequency = 1 / Period
Period = 1 / Frequency
Typical Frequencies:
- LED Dimming: 500 Hz - 20 kHz (above flicker perception ~60 Hz)
- Motor Control: 1 kHz - 40 kHz
- Audio: 40 kHz+ (above human hearing)
- Servo Motors: 50 Hz (20ms period)
Average Voltage
The average voltage delivered by PWM:
Average Voltage = Supply Voltage × (Duty Cycle / 100)
Example (5V supply):
- 0% duty → 0V average
- 25% duty → 1.25V average
- 50% duty → 2.5V average
- 100% duty → 5V average
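The two formulas above can be checked with a couple of one-line helpers:

```c
// Duty cycle and average output voltage, as defined above:
// Duty (%) = Ton / (Ton + Toff) * 100;  Vavg = Vsupply * Duty / 100
double duty_cycle_percent(double ton, double toff) {
    return ton / (ton + toff) * 100.0;
}

double pwm_average_voltage(double supply, double dutyPercent) {
    return supply * dutyPercent / 100.0;
}
```

With Ton = 1 ms and Toff = 3 ms the duty cycle is 25%, and at a 5V supply that averages out to 1.25V, matching the example list.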
Visual Representation
100% Duty Cycle (Always ON):
█████████████████████████████
75% Duty Cycle:
██████████████████░░░░░░░
50% Duty Cycle:
█████████████░░░░░░░░░░░░
25% Duty Cycle:
██████░░░░░░░░░░░░░░░░░░░
0% Duty Cycle (Always OFF):
░░░░░░░░░░░░░░░░░░░░░░░░░
How It Works
Hardware PWM vs Software PWM
| Feature | Hardware PWM | Software PWM |
|---|---|---|
| Precision | Very precise, timer-based | Can jitter with interrupts |
| CPU Load | Zero (handled by hardware) | High (CPU must toggle pin) |
| Pins | Limited (specific pins only) | Any digital pin |
| Frequency | High (up to MHz) | Low (few kHz max) |
| Recommended | Motors, audio, servos | Simple LED control |
PWM Resolution
Resolution is the number of distinct duty cycle levels available:
| Resolution | Levels | Step Size (at 5V) |
|---|---|---|
| 8-bit | 256 | 19.5 mV |
| 10-bit | 1024 | 4.88 mV |
| 12-bit | 4096 | 1.22 mV |
| 16-bit | 65536 | 76 μV |
Note: Higher resolution requires lower maximum frequency:
Max Frequency = Clock Frequency / (2^Resolution)
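The resolution/frequency trade-off above, written out as a one-line helper:

```c
#include <stdint.h>

// Maximum PWM frequency for a given timer clock and resolution:
// f_max = f_clk / 2^bits
double pwm_max_frequency(double clockHz, int bits) {
    return clockHz / (double)(1u << bits);
}
```

For example, a 16 MHz timer clock at 8-bit resolution tops out at 62.5 kHz; at 16 bits the same clock can only reach about 244 Hz, which is why high-resolution PWM runs at lower frequencies.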
Code Examples
Arduino PWM (Hardware)
// Arduino Uno PWM pins: 3, 5, 6, 9, 10, 11
// Default frequency: ~490 Hz (pins 3,9,10,11) and ~980 Hz (pins 5,6)
const int ledPin = 9;
const int motorPin = 10;
void setup() {
pinMode(ledPin, OUTPUT);
pinMode(motorPin, OUTPUT);
}
void loop() {
// analogWrite uses 8-bit resolution (0-255)
// LED at 25% brightness
analogWrite(ledPin, 64); // 64/255 = 25%
delay(1000);
// LED at 50% brightness
analogWrite(ledPin, 128); // 128/255 = 50%
delay(1000);
// LED at 75% brightness
analogWrite(ledPin, 192); // 192/255 = 75%
delay(1000);
// LED at 100% brightness
analogWrite(ledPin, 255); // 255/255 = 100%
delay(1000);
}
// Smooth fade effect
void fadeLED() {
// Fade in
for (int brightness = 0; brightness <= 255; brightness++) {
analogWrite(ledPin, brightness);
delay(5);
}
// Fade out
for (int brightness = 255; brightness >= 0; brightness--) {
analogWrite(ledPin, brightness);
delay(5);
}
}
// Change PWM frequency (Arduino Uno)
void setPWMFrequency(int pin, int divisor) {
byte mode;
if (pin == 5 || pin == 6 || pin == 9 || pin == 10) {
switch(divisor) {
case 1: mode = 0x01; break; // 31.25 kHz
case 8: mode = 0x02; break; // 3.9 kHz
case 64: mode = 0x03; break; // 490 Hz (default for 9,10)
case 256: mode = 0x04; break; // 122 Hz
case 1024: mode = 0x05; break; // 30 Hz
default: return;
}
if (pin == 5 || pin == 6) {
TCCR0B = (TCCR0B & 0b11111000) | mode; // Timer0: also changes millis()/delay() timing!
} else {
TCCR1B = (TCCR1B & 0b11111000) | mode; // Timer1
}
}
}
void setup() {
pinMode(9, OUTPUT);
setPWMFrequency(9, 1); // Set pin 9 to 31.25 kHz
}
ESP32 PWM (LEDC)
// ESP32 uses LEDC (LED Control) for PWM
// 16 independent channels, configurable frequency and resolution
const int ledPin = 25;
const int pwmChannel = 0; // Channel 0-15
const int pwmFrequency = 5000; // 5 kHz
const int pwmResolution = 8; // 8-bit (0-255)
void setup() {
// Configure PWM channel
ledcSetup(pwmChannel, pwmFrequency, pwmResolution);
// Attach pin to PWM channel
ledcAttachPin(ledPin, pwmChannel);
}
void loop() {
// Set duty cycle (0-255 for 8-bit)
ledcWrite(pwmChannel, 128); // 50% duty cycle
delay(1000);
ledcWrite(pwmChannel, 64); // 25% duty cycle
delay(1000);
}
// High resolution PWM (16-bit)
void setupHighResPWM() {
const int pwmChannel = 0;
const int pwmFreq = 1000; // Lower freq for higher resolution
const int pwmRes = 16; // 16-bit (0-65535)
ledcSetup(pwmChannel, pwmFreq, pwmRes);
ledcAttachPin(ledPin, pwmChannel);
// Set to 50% with 16-bit precision
ledcWrite(pwmChannel, 32768);
}
// Multiple PWM channels for RGB LED
const int redPin = 25;
const int greenPin = 26;
const int bluePin = 27;
void setupRGB() {
ledcSetup(0, 5000, 8); // Red channel
ledcSetup(1, 5000, 8); // Green channel
ledcSetup(2, 5000, 8); // Blue channel
ledcAttachPin(redPin, 0);
ledcAttachPin(greenPin, 1);
ledcAttachPin(bluePin, 2);
}
void setRGBColor(uint8_t r, uint8_t g, uint8_t b) {
ledcWrite(0, r);
ledcWrite(1, g);
ledcWrite(2, b);
}
void loop() {
setRGBColor(255, 0, 0); // Red
delay(1000);
setRGBColor(0, 255, 0); // Green
delay(1000);
setRGBColor(0, 0, 255); // Blue
delay(1000);
setRGBColor(255, 255, 0); // Yellow
delay(1000);
}
STM32 HAL PWM
#include "stm32f4xx_hal.h"
TIM_HandleTypeDef htim3;
void PWM_Init(void) {
TIM_OC_InitTypeDef sConfigOC = {0};
// Timer 3 configuration for PWM
htim3.Instance = TIM3;
htim3.Init.Prescaler = 84 - 1; // 84 MHz / 84 = 1 MHz
htim3.Init.CounterMode = TIM_COUNTERMODE_UP;
htim3.Init.Period = 1000 - 1; // 1 MHz / 1000 = 1 kHz PWM
htim3.Init.ClockDivision = TIM_CLOCKDIVISION_DIV1;
HAL_TIM_PWM_Init(&htim3);
// Configure PWM channel 1
sConfigOC.OCMode = TIM_OCMODE_PWM1;
sConfigOC.Pulse = 500; // 50% duty cycle
sConfigOC.OCPolarity = TIM_OCPOLARITY_HIGH;
sConfigOC.OCFastMode = TIM_OCFAST_DISABLE;
HAL_TIM_PWM_ConfigChannel(&htim3, &sConfigOC, TIM_CHANNEL_1);
// Start PWM
HAL_TIM_PWM_Start(&htim3, TIM_CHANNEL_1);
}
void PWM_SetDutyCycle(uint16_t dutyCycle) {
// dutyCycle: 0-1000 (0-100%)
__HAL_TIM_SET_COMPARE(&htim3, TIM_CHANNEL_1, dutyCycle);
}
void PWM_SetPercent(uint8_t percent) {
// percent: 0-100
uint16_t pulse = (percent * 1000) / 100;
PWM_SetDutyCycle(pulse);
}
Servo Control with PWM
// Standard servo: 50 Hz (20ms period)
// Pulse width: 1ms (0°) to 2ms (180°)
const int servoPin = 9;
void setup() {
pinMode(servoPin, OUTPUT);
}
void setServoAngle(int angle) {
// Map angle (0-180) to pulse width (1000-2000 μs)
int pulseWidth = map(angle, 0, 180, 1000, 2000);
// Generate 50 Hz PWM signal
digitalWrite(servoPin, HIGH);
delayMicroseconds(pulseWidth);
digitalWrite(servoPin, LOW);
delayMicroseconds(20000 - pulseWidth); // Complete 20ms period
}
void loop() {
setServoAngle(0); // 0 degrees
delay(1000);
setServoAngle(90); // 90 degrees
delay(1000);
setServoAngle(180); // 180 degrees
delay(1000);
}
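The `map()` call above does the angle-to-pulse conversion. Its arithmetic, written out as a standalone C function (same integer math as Arduino's `map()`, renamed to avoid clashing with that macro):

```c
// Linear range mapping: angle 0-180 -> pulse width 1000-2000 us.
// Same integer arithmetic as Arduino's map().
long map_range(long x, long inMin, long inMax, long outMin, long outMax) {
    return (x - inMin) * (outMax - outMin) / (inMax - inMin) + outMin;
}
```

So 0° maps to a 1000 µs pulse, 90° to 1500 µs, and 180° to 2000 µs, which is the nominal pulse range a standard hobby servo expects every 20 ms.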
// Using Servo library (easier)
#include <Servo.h>
Servo myServo;
void setup() {
myServo.attach(9); // Attach servo to pin 9
}
void loop() {
myServo.write(0); // 0 degrees
delay(1000);
myServo.write(90); // 90 degrees
delay(1000);
myServo.write(180); // 180 degrees
delay(1000);
}
Motor Control (H-Bridge)
// Control DC motor speed and direction with L298N H-Bridge
const int motorPWM = 9; // Speed control (PWM)
const int motorIN1 = 7; // Direction control
const int motorIN2 = 8; // Direction control
void setup() {
pinMode(motorPWM, OUTPUT);
pinMode(motorIN1, OUTPUT);
pinMode(motorIN2, OUTPUT);
}
void setMotorSpeed(int speed) {
// speed: -255 (full reverse) to +255 (full forward)
if (speed > 0) {
// Forward
digitalWrite(motorIN1, HIGH);
digitalWrite(motorIN2, LOW);
analogWrite(motorPWM, speed);
} else if (speed < 0) {
// Reverse
digitalWrite(motorIN1, LOW);
digitalWrite(motorIN2, HIGH);
analogWrite(motorPWM, -speed);
} else {
// Stop
digitalWrite(motorIN1, LOW);
digitalWrite(motorIN2, LOW);
analogWrite(motorPWM, 0);
}
}
void loop() {
setMotorSpeed(128); // 50% forward
delay(2000);
setMotorSpeed(255); // 100% forward
delay(2000);
setMotorSpeed(0); // Stop
delay(1000);
setMotorSpeed(-128); // 50% reverse
delay(2000);
}
Common Applications
1. LED Dimming
// Smooth breathing effect
void breathingLED(int pin) {
const int maxBrightness = 255;
const int minBrightness = 0;
const int step = 5;
const int delayTime = 30;
// Breathe in
for (int brightness = minBrightness; brightness <= maxBrightness; brightness += step) {
analogWrite(pin, brightness);
delay(delayTime);
}
// Breathe out
for (int brightness = maxBrightness; brightness >= minBrightness; brightness -= step) {
analogWrite(pin, brightness);
delay(delayTime);
}
}
2. RGB Color Mixing
const int redPin = 9, greenPin = 10, bluePin = 11; // example PWM-capable pins
void setColorHSV(float h, float s, float v) {
// Convert HSV to RGB
float c = v * s;
float x = c * (1 - abs(fmod(h / 60.0, 2) - 1));
float m = v - c;
float r, g, b;
if (h < 60) { r = c; g = x; b = 0; }
else if (h < 120) { r = x; g = c; b = 0; }
else if (h < 180) { r = 0; g = c; b = x; }
else if (h < 240) { r = 0; g = x; b = c; }
else if (h < 300) { r = x; g = 0; b = c; }
else { r = c; g = 0; b = x; }
analogWrite(redPin, (r + m) * 255);
analogWrite(greenPin, (g + m) * 255);
analogWrite(bluePin, (b + m) * 255);
}
// Rainbow effect
void rainbow() {
for (int hue = 0; hue < 360; hue++) {
setColorHSV(hue, 1.0, 1.0);
delay(10);
}
}
3. Speaker/Buzzer Tone Generation
void playTone(int pin, int frequency, int duration) {
int period = 1000000 / frequency; // Period in microseconds
int halfPeriod = period / 2;
long cycles = ((long)frequency * duration) / 1000;
for (long i = 0; i < cycles; i++) {
digitalWrite(pin, HIGH);
delayMicroseconds(halfPeriod);
digitalWrite(pin, LOW);
delayMicroseconds(halfPeriod);
}
}
const int buzzerPin = 8; // example buzzer pin
void playMelody() {
playTone(buzzerPin, 262, 500); // C4
playTone(buzzerPin, 294, 500); // D4
playTone(buzzerPin, 330, 500); // E4
playTone(buzzerPin, 349, 500); // F4
}
// Using tone() function (easier)
void playNote(int frequency) {
tone(buzzerPin, frequency);
delay(500);
noTone(buzzerPin);
}
4. Fan Speed Control
const int fanPin = 9; // PWM-capable pin (example)
int targetTemp = 25; // Target temperature
int currentTemp = 30; // Read from sensor
void controlFan() {
int tempDiff = currentTemp - targetTemp;
int fanSpeed;
if (tempDiff <= 0) {
fanSpeed = 0; // Too cold, fan off
} else if (tempDiff >= 10) {
fanSpeed = 255; // Very hot, max speed
} else {
// Proportional control
fanSpeed = map(tempDiff, 0, 10, 50, 255);
}
analogWrite(fanPin, fanSpeed);
}
5. Power Supply (Buck Converter)
PWM is used in switching power supplies to efficiently convert voltage:
- High frequency (20-100 kHz) minimizes inductor size
- Duty cycle controls output voltage
- Feedback loop maintains regulation
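The duty-cycle relation can be sketched numerically. This is only the ideal continuous-conduction model (function names are illustrative); real converters close the loop with feedback:

```c
/* Ideal buck converter in continuous conduction: Vout = D * Vin.
   First-order model only; a real converter regulates D with feedback. */
double buck_vout(double vin, double duty) {  /* duty: 0.0 to 1.0 */
    return vin * duty;
}

/* Duty cycle needed for a target output voltage (ideal model) */
double buck_duty_for(double vin, double vout) {
    return vout / vin;
}
```

For example, stepping 12 V down to 6 V calls for a 50% duty cycle in the ideal case.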
Best Practices
1. Choose Appropriate Frequency
// LED dimming: Use higher frequency to avoid flicker
// Human eye perceives flicker below ~60 Hz
setPWMFrequency(ledPin, 1); // ~31 kHz (setPWMFrequency is a user-defined helper that sets the timer prescaler, not a core Arduino function)
// Motor control: Balance between smoothness and efficiency
// Too high: Increased switching losses
// Too low: Audible noise, torque ripple
// Optimal: 10-25 kHz
// Audio: Must be above hearing range
// Humans hear up to ~20 kHz
// Use 40+ kHz for audio PWM
2. Filter PWM for Analog Output
PWM Pin ─── R ───┬─── Analog Output
                 │
                 C
                 │
                GND
Cutoff Frequency = 1 / (2π × R × C)
Example: R=1kΩ, C=10μF
fc = 1 / (2π × 1000 × 0.00001) ≈ 16 Hz
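The cutoff formula is easy to check in code; `rc_cutoff_hz` is an illustrative helper, not a library function:

```c
/* First-order RC low-pass cutoff: fc = 1 / (2 * pi * R * C) */
double rc_cutoff_hz(double r_ohms, double c_farads) {
    const double TWO_PI = 6.283185307179586;
    return 1.0 / (TWO_PI * r_ohms * c_farads);
}
```

With R = 1 kΩ and C = 10 µF this returns about 15.9 Hz, the ~16 Hz quoted above.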
3. Protect Inductive Loads
// Motors and solenoids are inductive
// Add flyback diode across load!
//      Motor
//  ┌────┴────┐
// PWM       GND
//  │   ▼─   │
//  └────────┘
//  Flyback diode across the motor terminals
//  (reverse-biased in normal operation)
4. Avoid PWM on Critical Pins
// Some Arduino pins share timers
// Changing frequency on one affects others!
// Pins 5 & 6 share Timer 0 (also used by millis/delay!)
// Pins 9 & 10 share Timer 1
// Pins 3 & 11 share Timer 2
// Changing Timer 0 frequency breaks millis() and delay()!
Common Issues and Debugging
Problem: LED Flickering
Causes: PWM frequency too low
Solution: Increase frequency above 60 Hz (ideally 500 Hz+)
Problem: Motor Whining/Buzzing
Causes: PWM frequency in the audible range
Solution: Increase frequency to 20+ kHz
Problem: Servo Jittering
Causes: Incorrect pulse width or timing
Solution: Use the dedicated Servo library and ensure a 50 Hz signal
Problem: PWM Not Working After Changing Frequency
Causes: Modified Timer 0, which breaks delay() and millis()
Solution: Use a different timer, or use an external library
ELI10 (Explain Like I'm 10)
Imagine you have a light switch that you can flick on and off really, really fast - so fast that your eyes can't see it blinking!
PWM is like that super-fast blinking:
- If the light is ON for half the time and OFF for half the time, it looks 50% bright
- If it's ON for most of the time and OFF for a tiny bit, it looks almost fully bright
- If it's ON for only a tiny bit and OFF most of the time, it looks dim
This works because:
- Your eyes can't see things blinking faster than about 60 times per second
- So when we blink the light 500 or 1000 times per second, your brain sees a steady dimmed light!
The cool part?
- We're not actually reducing the voltage (which wastes energy as heat)
- We're just turning it on and off really fast (very efficient!)
- It's like running at full speed for short bursts vs. walking slowly all the time
Duty cycle is the percentage of time it's ON:
- 100% = always on (full brightness)
- 50% = on half the time (half brightness)
- 0% = always off (no light)
We use this same trick for controlling motor speeds, speakers, and lots of other things!
Further Resources
- Arduino PWM Guide
- Secrets of Arduino PWM
- ESP32 LEDC Documentation
- PWM Wikipedia
- Motor Control with PWM
ADC (Analog-to-Digital Converter)
Overview
An Analog-to-Digital Converter (ADC) is a hardware component that converts continuous analog signals (like voltage, temperature, light intensity) into discrete digital values that a microcontroller can process. ADCs are essential for interfacing with the real world, enabling microcontrollers to read sensors and analog inputs.
Key Concepts
What is an Analog Signal?
An analog signal is a continuous signal that can have any value within a range. Examples:
- Temperature: 0°C to 100°C
- Light intensity: 0 to maximum brightness
- Audio: continuous sound waves
- Voltage: 0V to 5V
What is a Digital Value?
A digital value is a discrete number that represents the analog signal:
- 8-bit ADC: 0 to 255 (256 possible values)
- 10-bit ADC: 0 to 1023 (1024 possible values)
- 12-bit ADC: 0 to 4095 (4096 possible values)
Resolution
Resolution determines how finely an ADC can distinguish between different analog values.
Formula:
Resolution = Reference Voltage / (2^n - 1)
Where n = number of bits
Examples:
| Bits | Levels | Resolution (5V ref) | Resolution (3.3V ref) |
|---|---|---|---|
| 8-bit | 256 | 19.6 mV | 12.9 mV |
| 10-bit | 1024 | 4.88 mV | 3.22 mV |
| 12-bit | 4096 | 1.22 mV | 0.81 mV |
| 16-bit | 65536 | 76.3 uV | 50.4 uV |
What this means: A 10-bit ADC with 5V reference can distinguish voltage differences as small as ~4.88mV.
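The resolution formula can be written as a small helper (our name, not a library call):

```c
/* LSB step size of an n-bit ADC: Vref / (2^n - 1) */
double adc_lsb_volts(double vref, unsigned bits) {
    return vref / (double)((1UL << bits) - 1);
}
```

For example, `adc_lsb_volts(5.0, 10)` gives ~4.88 mV and `adc_lsb_volts(3.3, 12)` gives ~0.81 mV, matching the table.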
Reference Voltage (VREF)
The reference voltage defines the maximum input voltage the ADC can measure.
- Arduino Uno: 5V (can use external ref)
- ESP32: 3.3V (default), 1.1V (attenuated)
- STM32: 3.3V (typically)
Important: Never exceed VREF on analog input pins!
Sampling Rate
How many times per second the ADC can take a measurement, measured in:
- SPS: Samples Per Second
- kSPS: Thousand samples per second
- MSPS: Million samples per second
Examples:
- Arduino Uno: ~10 kSPS
- ESP32: ~100 kSPS
- STM32F4: Up to 2.4 MSPS
- External ADC (ADS1115): 860 SPS max
How It Works
Conversion Process
- Sample: Capture the analog voltage at a specific moment
- Hold: Maintain that voltage level during conversion
- Quantize: Divide the voltage range into discrete levels
- Encode: Convert to a binary number
Analog Input (2.5V) -> ADC -> Digital Output (512 for 10-bit at 5V ref)
Calculation: 2.5V / 5V * 1023 = 511.5 ~ 512
Conversion Formula
Digital Value = (Analog Voltage / Reference Voltage) * (2^n - 1)
Analog Voltage = (Digital Value / (2^n - 1)) * Reference Voltage
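Both conversion formulas as plain C helpers (names are ours):

```c
#include <stdint.h>

/* Raw ADC code -> voltage for an n-bit converter */
double adc_to_volts(uint32_t raw, double vref, unsigned bits) {
    return ((double)raw / (double)((1UL << bits) - 1)) * vref;
}

/* Voltage -> nearest raw ADC code (rounded) */
uint32_t volts_to_adc(double volts, double vref, unsigned bits) {
    return (uint32_t)(volts / vref * (double)((1UL << bits) - 1) + 0.5);
}
```

This reproduces the worked example above: 2.5 V on a 10-bit, 5 V-reference ADC rounds to code 512.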
Code Examples
Arduino (AVR) ADC
// Simple analog read
const int sensorPin = A0;
void setup() {
Serial.begin(9600);
// Optional: Set analog reference
// analogReference(DEFAULT); // 5V on Uno
// analogReference(INTERNAL); // 1.1V internal reference
// analogReference(EXTERNAL); // External AREF pin
}
void loop() {
// Read analog value (0-1023)
int rawValue = analogRead(sensorPin);
// Convert to voltage
float voltage = rawValue * (5.0 / 1023.0);
Serial.print("Raw: ");
Serial.print(rawValue);
Serial.print(" | Voltage: ");
Serial.print(voltage);
Serial.println(" V");
delay(500);
}
// Reading multiple analog pins
void readMultipleSensors() {
int sensors[] = {A0, A1, A2, A3};
for (int i = 0; i < 4; i++) {
int value = analogRead(sensors[i]);
float voltage = value * (5.0 / 1023.0);
Serial.print("Sensor ");
Serial.print(i);
Serial.print(": ");
Serial.println(voltage);
}
}
ESP32 ADC
// ESP32 has two ADC units with multiple channels
const int analogPin = 34; // ADC1_CH6 (GPIO 34)
void setup() {
Serial.begin(115200);
// Set ADC resolution (9-12 bits)
analogReadResolution(12); // Default is 12 bits (0-4095)
// Set ADC attenuation (changes measurement range)
// ADC_0db: 0-1.1V
// ADC_2_5db: 0-1.5V
// ADC_6db: 0-2.2V
// ADC_11db: 0-3.3V (default in the Arduino-ESP32 core)
analogSetAttenuation(ADC_11db);
// Or set per pin
analogSetPinAttenuation(analogPin, ADC_11db);
}
void loop() {
int rawValue = analogRead(analogPin);
// Convert to voltage (with 11db attenuation, 0-3.3V range)
// Note: ESP32 ADC is non-linear, consider calibration
float voltage = rawValue * (3.3 / 4095.0);
Serial.print("Raw: ");
Serial.print(rawValue);
Serial.print(" | Voltage: ");
Serial.println(voltage);
delay(100);
}
// Better: Use calibrated read
#include "esp_adc_cal.h"
esp_adc_cal_characteristics_t adc_chars;
void setupCalibrated() {
esp_adc_cal_characterize(ADC_UNIT_1, ADC_ATTEN_DB_11,
ADC_WIDTH_BIT_12, 1100, &adc_chars);
}
void loopCalibrated() {
uint32_t voltage = analogRead(analogPin);
voltage = esp_adc_cal_raw_to_voltage(voltage, &adc_chars);
Serial.print("Calibrated voltage: ");
Serial.print(voltage);
Serial.println(" mV");
}
STM32 HAL ADC
#include "stm32f4xx_hal.h"
ADC_HandleTypeDef hadc1;
void ADC_Init(void) {
ADC_ChannelConfTypeDef sConfig = {0};
// Configure ADC
hadc1.Instance = ADC1;
hadc1.Init.ClockPrescaler = ADC_CLOCK_SYNC_PCLK_DIV4;
hadc1.Init.Resolution = ADC_RESOLUTION_12B;
hadc1.Init.ScanConvMode = DISABLE;
hadc1.Init.ContinuousConvMode = DISABLE;
hadc1.Init.DiscontinuousConvMode = DISABLE;
hadc1.Init.ExternalTrigConv = ADC_SOFTWARE_START;
hadc1.Init.DataAlign = ADC_DATAALIGN_RIGHT;
hadc1.Init.NbrOfConversion = 1;
HAL_ADC_Init(&hadc1);
// Configure channel
sConfig.Channel = ADC_CHANNEL_0;
sConfig.Rank = 1;
sConfig.SamplingTime = ADC_SAMPLETIME_84CYCLES;
HAL_ADC_ConfigChannel(&hadc1, &sConfig);
}
uint16_t ADC_Read(uint32_t channel) {
ADC_ChannelConfTypeDef sConfig = {0};
sConfig.Channel = channel;
sConfig.Rank = 1;
HAL_ADC_ConfigChannel(&hadc1, &sConfig);
// Start conversion
HAL_ADC_Start(&hadc1);
// Wait for conversion to complete
HAL_ADC_PollForConversion(&hadc1, 100);
// Read value
uint16_t value = HAL_ADC_GetValue(&hadc1);
return value;
}
float ADC_ReadVoltage(uint32_t channel) {
uint16_t raw = ADC_Read(channel);
// Convert to voltage (assuming 3.3V reference)
float voltage = (raw * 3.3f) / 4095.0f;
return voltage;
}
// DMA-based continuous conversion
uint16_t adc_buffer[16];
void ADC_Start_DMA(void) {
HAL_ADC_Start_DMA(&hadc1, (uint32_t*)adc_buffer, 16);
}
External ADC (ADS1115) via I2C
#include <Wire.h>
#include <Adafruit_ADS1X15.h>
Adafruit_ADS1115 ads; // 16-bit ADC
void setup() {
Serial.begin(115200);
// Initialize ADS1115
if (!ads.begin()) {
Serial.println("Failed to initialize ADS1115!");
while (1);
}
// Set gain
// ads.setGain(GAIN_TWOTHIRDS); // +/-6.144V range (library default)
// ads.setGain(GAIN_ONE); // +/-4.096V range
ads.setGain(GAIN_TWO); // +/-2.048V range
// ads.setGain(GAIN_FOUR); // +/-1.024V range
// ads.setGain(GAIN_EIGHT); // +/-0.512V range
// ads.setGain(GAIN_SIXTEEN); // +/-0.256V range
}
void loop() {
// Read single-ended from channel 0
int16_t adc0 = ads.readADC_SingleEnded(0);
float voltage0 = ads.computeVolts(adc0);
// Read differential (channel 0 - channel 1)
int16_t diff01 = ads.readADC_Differential_0_1();
Serial.print("ADC0: ");
Serial.print(adc0);
Serial.print(" | Voltage: ");
Serial.println(voltage0);
delay(100);
}
Common Applications
1. Temperature Sensors (Thermistor)
const int thermistorPin = A0;
const float BETA = 3950; // Beta coefficient
const float R0 = 10000; // Resistance at 25°C
const float T0 = 298.15; // 25°C in Kelvin
float readTemperature() {
int raw = analogRead(thermistorPin);
// Convert to resistance
float R = 10000.0 * (1023.0 / raw - 1.0);
// Steinhart-Hart equation
float T = 1.0 / (1.0/T0 + (1.0/BETA) * log(R/R0));
return T - 273.15; // Convert to Celsius
}
2. Light Sensor (LDR/Photoresistor)
const int ldrPin = A1;
int readLightLevel() {
int rawValue = analogRead(ldrPin);
// Convert to percentage
int lightPercent = map(rawValue, 0, 1023, 0, 100);
return lightPercent;
}
3. Potentiometer (Volume Control)
const int potPin = A2;
const int ledPin = 9; // PWM pin
void setup() {
pinMode(ledPin, OUTPUT);
}
void loop() {
int potValue = analogRead(potPin);
// Map to PWM range (0-255)
int brightness = map(potValue, 0, 1023, 0, 255);
analogWrite(ledPin, brightness);
}
4. Battery Voltage Monitoring
const int batteryPin = A3;
const float voltageDividerRatio = 2.0; // R1=R2=10k
float readBatteryVoltage() {
int raw = analogRead(batteryPin);
// Convert to actual voltage
float adcVoltage = raw * (5.0 / 1023.0);
// Account for voltage divider
float batteryVoltage = adcVoltage * voltageDividerRatio;
return batteryVoltage;
}
void checkBattery() {
float voltage = readBatteryVoltage();
if (voltage < 3.3) {
Serial.println("WARNING: Low battery!");
}
}
5. Current Sensing (ACS712)
const int currentSensorPin = A4;
const float sensitivity = 0.185; // 185mV/A for ACS712-05B
float readCurrent() {
int raw = analogRead(currentSensorPin);
float voltage = raw * (5.0 / 1023.0);
// Zero point is 2.5V (Vcc/2)
float offsetVoltage = voltage - 2.5;
// Calculate current
float current = offsetVoltage / sensitivity;
return current;
}
Best Practices
1. Averaging for Stability
float readAverageAnalog(int pin, int samples = 10) {
long sum = 0;
for (int i = 0; i < samples; i++) {
sum += analogRead(pin);
delay(10); // Small delay between reads
}
return (float)sum / samples;
}
2. Handling Noise
// Software low-pass filter (running average)
const int numReadings = 10;
int readings[numReadings];
int readIndex = 0;
int total = 0;
int smoothedRead(int pin) {
total -= readings[readIndex];
readings[readIndex] = analogRead(pin);
total += readings[readIndex];
readIndex = (readIndex + 1) % numReadings;
return total / numReadings;
}
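An alternative to the ring-buffer average is an exponential moving average, which needs only one state variable. A minimal C sketch under our own names and types (not a library API):

```c
/* Exponential moving average: state += alpha * (sample - state).
   alpha in (0, 1]; smaller alpha = heavier smoothing. */
typedef struct {
    double state;
    int primed;    /* 0 until the first sample seeds the state */
    double alpha;
} ema_t;

double ema_update(ema_t *f, double sample) {
    if (!f->primed) { f->state = sample; f->primed = 1; }
    else            { f->state += f->alpha * (sample - f->state); }
    return f->state;
}
```

Compared with the running average above, this trades a tunable response time for constant memory, which matters on small microcontrollers.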
3. Proper Voltage Divider
// To measure higher voltages, use a voltage divider:
//
// Vin ──[R1]──┬──[R2]── GND
//             │
//          ADC Pin
//
// Vout = Vin * (R2 / (R1 + R2))
// Example: Measure 12V with 5V ADC
// R1 = 10k ohm, R2 = 7.5k ohm
// Vout = 12V * (7.5 / 17.5) = 5.14V (slightly over, use 6.8k ohm for R2)
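The divider formula as a checkable helper (illustrative name):

```c
/* Voltage divider output: Vout = Vin * R2 / (R1 + R2) */
double divider_vout(double vin, double r1, double r2) {
    return vin * r2 / (r1 + r2);
}
```

With R1 = 10k and R2 = 7.5k, 12 V gives ~5.14 V (just over a 5 V reference); dropping R2 to 6.8k brings it down to ~4.86 V.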
4. Calibration
struct CalibrationData {
float slope;
float offset;
};
CalibrationData calibrate(int pin, float knownVoltage) {
int rawValue = analogRead(pin);
CalibrationData cal;
cal.slope = knownVoltage / rawValue;
cal.offset = 0; // Adjust if needed
return cal;
}
float calibratedRead(int pin, CalibrationData cal) {
int raw = analogRead(pin);
return (raw * cal.slope) + cal.offset;
}
Common Issues and Debugging
Problem: Noisy Readings
Solutions:
- Add 0.1uF capacitor between analog pin and ground
- Use averaging/filtering in software
- Keep analog wires short and away from digital signals
- Use twisted pair cables for long runs
- Add ferrite beads on long cables
Problem: Incorrect Voltage Readings
Check:
- Verify reference voltage is correct
- Check voltage divider calculations
- Ensure input doesn't exceed VREF
- Verify ground connection
Problem: Slow Response
Solutions:
- Reduce averaging samples
- Check ADC clock/prescaler settings
- Use faster ADC if needed (external)
- Enable DMA for continuous sampling
ELI10 (Explain Like I'm 10)
Imagine you have a thermometer that shows any temperature between 0°C and 100°C, but you can only report whole numbers:
- If the real temperature is 23.7°C, you might say "24°C"
- If it's 23.2°C, you might say "23°C"
An ADC does the same thing! It takes a smooth, continuous voltage (like the temperature) and converts it to a number your microcontroller can understand.
Resolution is like how many different numbers you can say:
- 8-bit ADC: Can say 256 different numbers (0-255)
- 10-bit ADC: Can say 1024 different numbers (0-1023)
- 12-bit ADC: Can say 4096 different numbers (0-4095)
More bits = more precise measurements = seeing smaller differences!
Further Resources
- ADC Tutorial - SparkFun
- Arduino analogRead() Reference
- ESP32 ADC Documentation
- ADC Noise Reduction Techniques - AVR
- Understanding ADC Parameters
DAC (Digital-to-Analog Converter)
Overview
A Digital-to-Analog Converter (DAC) does the opposite of an ADC - it converts discrete digital values from a microcontroller into continuous analog voltage signals. DACs are essential for generating analog outputs like audio signals, control voltages, and waveforms.
Key Concepts
What Does a DAC Do?
A DAC takes a digital number and outputs a corresponding analog voltage:
Digital Input (512) -> DAC -> Analog Output (2.5V)
For 10-bit DAC with 5V reference:
Voltage = (512 / 1023) * 5V = 2.5V
Resolution
Just like ADCs, DAC resolution determines output precision:
| Bits | Levels | Voltage Step (5V) | Voltage Step (3.3V) |
|---|---|---|---|
| 8-bit | 256 | 19.6 mV | 12.9 mV |
| 10-bit | 1024 | 4.88 mV | 3.22 mV |
| 12-bit | 4096 | 1.22 mV | 0.81 mV |
| 16-bit | 65536 | 76.3 uV | 50.4 uV |
DAC vs PWM
Many microcontrollers don't have true DAC outputs, but can simulate analog using PWM:
| Feature | True DAC | PWM |
|---|---|---|
| Output | True analog voltage | Digital pulses |
| Smoothness | Smooth DC voltage | Requires filtering |
| Speed | Fast settling | Limited by PWM frequency |
| Filtering | Not needed | Low-pass filter needed |
| Complexity | Hardware DAC required | Any digital pin |
| Use Cases | Audio, precise control | LED dimming, motor speed |
How It Works
Conversion Formula
Output Voltage = (Digital Value / (2^n - 1)) * Reference Voltage
Where:
- n = number of bits
- Digital Value = input code (0 to 2^n - 1)
- Reference Voltage = max output voltage
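The conversion formula as a C helper with clamping (name is ours):

```c
#include <stdint.h>

/* n-bit DAC output: Vout = code / (2^n - 1) * Vref */
double dac_out_volts(uint32_t code, double vref, unsigned bits) {
    uint32_t full = (1UL << bits) - 1;
    if (code > full) code = full;  /* clamp out-of-range codes */
    return (double)code / (double)full * vref;
}
```

This reproduces the example above: code 512 on a 10-bit, 5 V-reference DAC gives ~2.5 V, and full scale lands exactly on Vref.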
Common DAC Architectures
- R-2R Ladder: Uses resistor network (simple, cheap)
- Binary Weighted: Uses weighted current sources
- Delta-Sigma: High resolution, used in audio
- String: Resistor divider network
Code Examples
Arduino Due (Built-in 12-bit DAC)
// Arduino Due has two DAC pins: DAC0 and DAC1
void setup() {
analogWriteResolution(12); // Set DAC resolution to 12 bits (0-4095)
}
void loop() {
// Output 1.65V on DAC0 (half of 3.3V reference)
analogWrite(DAC0, 2048); // 2048 / 4095 * 3.3V = 1.65V
delay(1000);
// Ramp voltage from 0V to 3.3V
for (int value = 0; value < 4096; value++) {
analogWrite(DAC0, value);
delayMicroseconds(100);
}
}
// Generate sine wave
void generateSineWave() {
const int samples = 100;
float frequency = 1000; // 1 kHz
for (int i = 0; i < samples; i++) {
float angle = (2.0 * PI * i) / samples;
int value = (sin(angle) + 1.0) * 2047.5; // Scale to 0-4095
analogWrite(DAC0, value);
delayMicroseconds(1000000 / (frequency * samples));
}
}
ESP32 (Built-in 8-bit DAC)
// ESP32 has two DAC channels: GPIO25 (DAC1) and GPIO26 (DAC2)
void setup() {
// No special initialization needed for DAC
}
void loop() {
// Output voltage (0-255 for 8-bit)
// 0 = 0V, 255 = 3.3V
dacWrite(25, 128); // Output ~1.65V on GPIO25
delay(1000);
}
// Generate sawtooth wave
void generateSawtoothWave() {
for (int value = 0; value < 256; value++) {
dacWrite(25, value);
delayMicroseconds(10);
}
}
// Generate triangle wave
void generateTriangleWave() {
// Rising edge
for (int value = 0; value < 256; value++) {
dacWrite(25, value);
delayMicroseconds(10);
}
// Falling edge
for (int value = 255; value >= 0; value--) {
dacWrite(25, value);
delayMicroseconds(10);
}
}
// Generate square wave
void generateSquareWave() {
dacWrite(25, 255); // HIGH
delay(1);
dacWrite(25, 0); // LOW
delay(1);
}
// Audio tone generation
void playTone(int frequency, int duration) {
const int samples = 32;
byte sineWave[samples];
// Pre-calculate sine wave
for (int i = 0; i < samples; i++) {
sineWave[i] = (sin(2.0 * PI * i / samples) + 1.0) * 127.5;
}
unsigned long startTime = millis();
int sampleDelay = 1000000 / (frequency * samples);
while (millis() - startTime < duration) {
for (int i = 0; i < samples; i++) {
dacWrite(25, sineWave[i]);
delayMicroseconds(sampleDelay);
}
}
}
STM32 HAL DAC
#include "stm32f4xx_hal.h"
DAC_HandleTypeDef hdac;
void DAC_Init(void) {
DAC_ChannelConfTypeDef sConfig = {0};
// Initialize DAC
hdac.Instance = DAC;
HAL_DAC_Init(&hdac);
// Configure DAC channel 1
sConfig.DAC_Trigger = DAC_TRIGGER_NONE;
sConfig.DAC_OutputBuffer = DAC_OUTPUTBUFFER_ENABLE;
HAL_DAC_ConfigChannel(&hdac, &sConfig, DAC_CHANNEL_1);
// Start DAC
HAL_DAC_Start(&hdac, DAC_CHANNEL_1);
}
void DAC_SetVoltage(float voltage) {
// Convert voltage to 12-bit value
// Assuming 3.3V reference
uint32_t value = (uint32_t)((voltage / 3.3f) * 4095.0f);
if (value > 4095) value = 4095;
HAL_DAC_SetValue(&hdac, DAC_CHANNEL_1, DAC_ALIGN_12B_R, value);
}
void DAC_SetValue(uint16_t value) {
HAL_DAC_SetValue(&hdac, DAC_CHANNEL_1, DAC_ALIGN_12B_R, value);
}
// DMA-based waveform generation
uint16_t sineWave[100];
void DAC_GenerateSineWave_DMA(void) {
// Pre-calculate sine wave
for (int i = 0; i < 100; i++) {
sineWave[i] = (uint16_t)((sin(2.0 * PI * i / 100.0) + 1.0) * 2047.5);
}
// Start DAC with DMA
HAL_DAC_Start_DMA(&hdac, DAC_CHANNEL_1, (uint32_t*)sineWave, 100,
DAC_ALIGN_12B_R);
// Configure timer to trigger DAC at specific rate
// This enables continuous waveform output
}
External DAC (MCP4725) via I2C
#include <Wire.h>
#include <Adafruit_MCP4725.h>
Adafruit_MCP4725 dac; // 12-bit DAC
void setup() {
Serial.begin(115200);
// Initialize MCP4725 (default address 0x62)
if (!dac.begin(0x62)) {
Serial.println("Failed to initialize MCP4725!");
while (1);
}
Serial.println("MCP4725 initialized!");
}
void loop() {
// Set voltage (0-4095 for 12-bit)
// Vout = (value / 4095) * Vdd
dac.setVoltage(2048, false); // Output ~1.65V (Vdd/2)
delay(1000);
}
// Ramp voltage smoothly
void rampVoltage(uint16_t start, uint16_t end, uint16_t steps) {
int16_t increment = (end - start) / steps;
for (uint16_t i = 0; i < steps; i++) {
uint16_t value = start + (i * increment);
dac.setVoltage(value, false);
delay(10);
}
dac.setVoltage(end, false); // land exactly on the target (integer division can undershoot)
}
// Generate precise voltage
void setVoltage(float voltage) {
// Assuming 5V Vdd
uint16_t value = (uint16_t)((voltage / 5.0) * 4095.0);
dac.setVoltage(value, false);
}
// Store value in EEPROM (survives power cycle)
void saveVoltage(uint16_t value) {
dac.setVoltage(value, true); // true = write to EEPROM
}
PWM as Pseudo-DAC (Arduino Uno)
// Arduino Uno doesn't have true DAC, use PWM with filtering
const int pwmPin = 9; // Any PWM pin
void setup() {
pinMode(pwmPin, OUTPUT);
// Increase PWM frequency for smoother output
// Default: 490 Hz for pins 5,6 and 980 Hz for others
// Setting for pin 9 and 10:
TCCR1B = (TCCR1B & 0b11111000) | 0x01; // 31.25 kHz (Timer 1 prescaler = 1)
}
void loop() {
// Output 2.5V (50% duty cycle with 5V Vdd)
analogWrite(pwmPin, 128); // 0-255 range
delay(1000);
}
// Hardware low-pass filter (required for PWM DAC):
//
// PWM Pin ──[1kΩ]──┬── Output
//                  │
//                10uF
//                  │
//                 GND
//
// Cutoff frequency = 1 / (2π × R × C) ≈ 16 Hz
// Convert voltage to PWM value
void setPWMVoltage(float voltage) {
int pwmValue = (int)((voltage / 5.0) * 255.0);
analogWrite(pwmPin, constrain(pwmValue, 0, 255));
}
Common Applications
1. Audio Output
// Simple audio playback
const byte audioSample[] = {128, 150, 172, 192, 209, ...};
const int sampleRate = 8000; // 8 kHz
void playAudio() {
for (int i = 0; i < sizeof(audioSample); i++) {
dacWrite(25, audioSample[i]);
delayMicroseconds(1000000 / sampleRate);
}
}
2. Voltage Reference Generation
// Generate precise reference voltage
void setReferenceVoltage(float voltage) {
// Using 12-bit DAC with 3.3V reference
uint16_t value = (uint16_t)((voltage / 3.3) * 4095);
analogWrite(DAC0, value);
}
// Example: Generate 1.024V reference
void setup() {
analogWriteResolution(12);
setReferenceVoltage(1.024); // Output constant 1.024V
}
3. Motor Speed Control
// Control motor speed with voltage
void setMotorSpeed(int speedPercent) {
// 0% = 0V, 100% = 3.3V
int dacValue = map(speedPercent, 0, 100, 0, 255);
dacWrite(25, dacValue);
}
4. LED Brightness (True Analog)
// Unlike PWM, DAC gives true DC voltage
void setLEDBrightness(int percent) {
int dacValue = map(percent, 0, 100, 0, 255);
dacWrite(25, dacValue);
// No flickering or PWM noise!
}
5. Signal Generation for Testing
// Generate test signals
void generateDCOffset(float voltage) {
uint16_t value = (uint16_t)((voltage / 3.3) * 4095);
analogWrite(DAC0, value);
}
// Programmable voltage divider
void setProgrammableVoltage(float targetVoltage) {
if (targetVoltage <= 3.3) {
generateDCOffset(targetVoltage);
}
}
Waveform Generation
Pre-calculated Waveform Tables
// Sine wave lookup table (256 samples)
const uint8_t sineTable[256] PROGMEM = {
127, 130, 133, 136, 139, 143, 146, 149,
152, 155, 158, 161, 164, 167, 170, 173,
// ... full 256 values
};
void generateSineFromTable(int frequency) {
int delayTime = 1000000 / (frequency * 256);
for (int i = 0; i < 256; i++) {
uint8_t value = pgm_read_byte(&sineTable[i]);
dacWrite(25, value);
delayMicroseconds(delayTime);
}
}
Best Practices
1. Output Filtering
For cleaner output, add RC low-pass filter:
DAC Out ──[100Ω]──┬── Output
                  │
                100nF
                  │
                 GND
2. Buffering
For driving loads, add op-amp buffer:
DAC Out ──(+) Op-Amp ──┬── Output
              (−) ─────┘
          (unity-gain feedback)
3. Settling Time
// Allow settling time after DAC update
void setDACWithSettling(uint16_t value) {
analogWrite(DAC0, value);
delayMicroseconds(10); // Wait for output to settle
}
4. Reference Voltage Stability
// Use external voltage reference for precision
// Internal reference can drift with temperature
Common Issues and Debugging
Problem: Output Voltage Incorrect
Check:
- Verify reference voltage
- Check calculation: (value / max) * Vref
- Ensure value doesn't exceed maximum
- Measure with high-impedance multimeter
Problem: Noisy Output
Solutions:
- Add output filter capacitor (100nF)
- Use separate analog ground
- Add decoupling caps near DAC (0.1uF)
- Keep output wires short
Problem: Can't Drive Load
Solutions:
- DAC outputs have limited current capability (~20mA typical)
- Add op-amp buffer for higher current
- Use darlington transistor for heavy loads
Problem: Distorted Waveforms
Check:
- Update rate too slow for frequency
- Insufficient sample resolution
- Loading effect (add buffer)
DAC Specifications to Consider
1. Resolution
- More bits = finer voltage control
- 8-bit usually sufficient for simple control
- 12-16 bit for audio and precision apps
2. Settling Time
- Time to reach final value
- Important for high-speed applications
- Typical: 1-10 us
3. Output Range
- Single-ended: 0V to Vref
- Bipolar: -Vref to +Vref (requires special circuit)
4. Update Rate
- How fast can DAC change values
- Audio: >40 kSPS
- Simple control: <1 kSPS
ELI10 (Explain Like I'm 10)
Remember ADC is like a thermometer that converts smooth temperatures to numbers? DAC is the opposite!
Imagine you have a light dimmer switch:
- Instead of smoothly turning the knob, you can only pick from specific positions
- 8-bit DAC: You have 256 positions (0-255)
- 12-bit DAC: You have 4096 positions (way more precise!)
The DAC takes your number choice and creates a voltage:
- Digital number 0 -> 0 volts
- Digital number 128 (half) -> 1.65 volts
- Digital number 255 (max) -> 3.3 volts
It's like having a volume knob that you control with numbers instead of turning it by hand!
PWM vs DAC: PWM is like flashing a light super fast to make it look dimmer. DAC is like actually turning down the voltage - it's smoother and better for some jobs!
Further Resources
- DAC Tutorial - SparkFun
- Arduino Due DAC Reference
- ESP32 DAC Documentation
- MCP4725 Datasheet
- Audio with Arduino DAC
Real-Time Clock (RTC) Modules
Comprehensive guide to RTC modules including DS1307, DS3231, and implementation examples.
Introduction
Real-Time Clock (RTC) modules are specialized integrated circuits that keep accurate time even when the main system is powered off. They are essential for data logging, scheduling, timestamps, and time-based applications.
Why Use an RTC Module?
- Accurate Timekeeping: Crystal oscillator provides precise time
- Low Power: Runs on backup battery for years
- Independent Operation: Maintains time when main power is off
- Calendar Functions: Handles dates, months, leap years automatically
- Alarms: Can trigger events at specific times
Popular RTC Modules
| Module | Crystal | Accuracy | Battery | Temperature | I2C Addr | Price |
|---|---|---|---|---|---|---|
| DS1307 | 32.768 kHz | ±2 min/month | CR2032 | No | 0x68 | $1 |
| DS3231 | 32.768 kHz (TCXO) | ±2 min/year | CR2032 | Yes | 0x68 | $2-5 |
| PCF8523 | 32.768 kHz | ±3 min/year | CR2032 | No | 0x68 | $2 |
| MCP7940N | 32.768 kHz | ±2 min/month | CR2032 | No | 0x6F | $1 |
RTC Basics
Time Representation
RTCs store time in BCD (Binary Coded Decimal) format:
Decimal 59 = 0101 1001 BCD
5 9
Decimal to BCD: 59 = (5 << 4) | 9 = 0x59
BCD to Decimal: 0x59 = ((0x59 >> 4) * 10) + (0x59 & 0x0F) = 59
BCD Conversion Functions
// Decimal to BCD
uint8_t dec_to_bcd(uint8_t val) {
return ((val / 10) << 4) | (val % 10);
}
// BCD to Decimal
uint8_t bcd_to_dec(uint8_t val) {
return ((val >> 4) * 10) + (val & 0x0F);
}
I2C Communication
All popular RTC modules use I2C interface:
Connections:
RTC VCC -> 3.3V or 5V
RTC GND -> GND
RTC SDA -> SDA (with pull-up resistor)
RTC SCL -> SCL (with pull-up resistor)
Pull-up resistors: 4.7kΩ typical
Wiring Diagram:
RTC Module            Microcontroller
    ┌────┐
VCC ┤    ├─ VCC (3.3V/5V)
GND ┤    ├─ GND
SDA ┤    ├─ SDA (with 4.7kΩ pull-up)
SCL ┤    ├─ SCL (with 4.7kΩ pull-up)
    └────┘
DS1307
Features
- Accuracy: ±2 minutes per month
- Operating Voltage: 4.5-5.5V (5V recommended)
- Battery Backup: CR2032 (typical)
- Interface: I2C (100 kHz)
- Address: 0x68 (fixed)
- RAM: 56 bytes of non-volatile SRAM
- Output: 1 Hz square wave
Register Map
Register Function
0x00 Seconds (00-59)
0x01 Minutes (00-59)
0x02 Hours (00-23 or 01-12)
0x03 Day of week (1-7)
0x04 Date (01-31)
0x05 Month (01-12)
0x06 Year (00-99)
0x07 Control (SQW output)
0x08-0x3F RAM (56 bytes)
Bit Layout (seconds register, 0x00):
Bit:  7  |  6 5 4  | 3 2 1 0
      CH | 10-sec  | seconds
         | (4 2 1) | (8 4 2 1)
CH = Clock Halt bit (0 = running, 1 = stopped)
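Decoding that register in C looks like this (`ds1307_seconds` is an illustrative helper, not a library call):

```c
#include <stdint.h>

/* Decode the DS1307 seconds register: bit 7 is the Clock Halt flag,
   bits 6-0 hold the seconds in BCD. */
uint8_t ds1307_seconds(uint8_t reg, int *halted) {
    if (halted) *halted = (reg >> 7) & 1;
    uint8_t bcd = reg & 0x7F;
    return (uint8_t)(((bcd >> 4) * 10) + (bcd & 0x0F));
}
```

So a raw register value of 0x59 decodes to 59 seconds with the clock running, while 0xB0 decodes to 30 seconds with the clock halted.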
Arduino DS1307 Library
#include <Wire.h>
#include <RTClib.h>
RTC_DS1307 rtc;
void setup() {
Serial.begin(9600);
Wire.begin();
if (!rtc.begin()) {
Serial.println("Couldn't find RTC");
while (1);
}
if (!rtc.isrunning()) {
Serial.println("RTC is NOT running, setting time...");
// Set to compile time
rtc.adjust(DateTime(F(__DATE__), F(__TIME__)));
// Or set manually:
// rtc.adjust(DateTime(2024, 1, 15, 12, 30, 0));
}
}
void loop() {
DateTime now = rtc.now();
Serial.print(now.year(), DEC);
Serial.print('/');
Serial.print(now.month(), DEC);
Serial.print('/');
Serial.print(now.day(), DEC);
Serial.print(" ");
Serial.print(now.hour(), DEC);
Serial.print(':');
Serial.print(now.minute(), DEC);
Serial.print(':');
Serial.println(now.second(), DEC);
delay(1000);
}
DS1307 Bare Metal (Arduino)
#include <Wire.h>
#define DS1307_ADDR 0x68
uint8_t dec_to_bcd(uint8_t val) {
return ((val / 10) << 4) | (val % 10);
}
uint8_t bcd_to_dec(uint8_t val) {
return ((val >> 4) * 10) + (val & 0x0F);
}
void ds1307_set_time(uint8_t hour, uint8_t min, uint8_t sec) {
Wire.beginTransmission(DS1307_ADDR);
Wire.write(0x00); // Start at seconds register
Wire.write(dec_to_bcd(sec) & 0x7F); // Clear CH bit
Wire.write(dec_to_bcd(min));
Wire.write(dec_to_bcd(hour));
Wire.endTransmission();
}
void ds1307_set_date(uint8_t day, uint8_t date, uint8_t month, uint8_t year) {
Wire.beginTransmission(DS1307_ADDR);
Wire.write(0x03); // Start at day register
Wire.write(dec_to_bcd(day));
Wire.write(dec_to_bcd(date));
Wire.write(dec_to_bcd(month));
Wire.write(dec_to_bcd(year));
Wire.endTransmission();
}
void ds1307_read_time(uint8_t *hour, uint8_t *min, uint8_t *sec) {
Wire.beginTransmission(DS1307_ADDR);
Wire.write(0x00); // Start at seconds register
Wire.endTransmission();
Wire.requestFrom(DS1307_ADDR, 3);
*sec = bcd_to_dec(Wire.read() & 0x7F);
*min = bcd_to_dec(Wire.read());
*hour = bcd_to_dec(Wire.read());
}
void setup() {
Serial.begin(9600);
Wire.begin();
// Set time: 12:30:00
ds1307_set_time(12, 30, 0);
// Set date: Monday, 15/01/24
ds1307_set_date(1, 15, 1, 24);
}
void loop() {
uint8_t hour, min, sec;
ds1307_read_time(&hour, &min, &sec);
Serial.print(hour);
Serial.print(":");
Serial.print(min);
Serial.print(":");
Serial.println(sec);
delay(1000);
}
DS3231
Features
- Accuracy: ±2 minutes per year (much better than DS1307)
- Temperature Compensated: TCXO provides better accuracy
- Operating Voltage: 2.3-5.5V
- Battery Backup: CR2032
- Interface: I2C (100-400 kHz)
- Address: 0x68 (fixed)
- Temperature Sensor: Built-in (±3°C accuracy)
- Alarms: Two programmable alarms
- Square Wave Output: 1Hz, 1.024kHz, 4.096kHz, 8.192kHz
Register Map
| Register | Function |
|---|---|
| 0x00 | Seconds (00-59) |
| 0x01 | Minutes (00-59) |
| 0x02 | Hours (00-23 or 01-12) |
| 0x03 | Day of week (1-7) |
| 0x04 | Date (01-31) |
| 0x05 | Month/Century (01-12) |
| 0x06 | Year (00-99) |
| 0x07-0x0A | Alarm 1 |
| 0x0B-0x0D | Alarm 2 |
| 0x0E | Control |
| 0x0F | Control/Status |
| 0x10 | Aging offset |
| 0x11-0x12 | Temperature |
Arduino DS3231 Library
#include <Wire.h>
#include <RTClib.h>
RTC_DS3231 rtc;
void setup() {
Serial.begin(9600);
Wire.begin();
if (!rtc.begin()) {
Serial.println("Couldn't find RTC");
while (1);
}
if (rtc.lostPower()) {
Serial.println("RTC lost power, setting time...");
rtc.adjust(DateTime(F(__DATE__), F(__TIME__)));
}
}
void loop() {
DateTime now = rtc.now();
// Print time
Serial.print(now.year(), DEC);
Serial.print('/');
Serial.print(now.month(), DEC);
Serial.print('/');
Serial.print(now.day(), DEC);
Serial.print(" ");
Serial.print(now.hour(), DEC);
Serial.print(':');
Serial.print(now.minute(), DEC);
Serial.print(':');
Serial.print(now.second(), DEC);
// Print temperature
float temp = rtc.getTemperature();
Serial.print(" Temp: ");
Serial.print(temp);
Serial.println("°C");
delay(1000);
}
DS3231 Alarm Example
#include <Wire.h>
#include <RTClib.h>
RTC_DS3231 rtc;
void setup() {
Serial.begin(9600);
Wire.begin();
rtc.begin();
// Clear stale alarm flags and disable the unused alarm 2 first
rtc.clearAlarm(1);
rtc.clearAlarm(2);
rtc.disableAlarm(2);
rtc.writeSqwPinMode(DS3231_OFF); // Use the SQW pin for alarms, not square wave
// Set alarm 1 for every day at 12:30:00 (matches on hours, minutes, seconds)
rtc.setAlarm1(DateTime(0, 0, 0, 12, 30, 0), DS3231_A1_Hour);
}
void loop() {
if (rtc.alarmFired(1)) {
Serial.println("Alarm 1 triggered!");
rtc.clearAlarm(1);
}
delay(1000);
}
DS3231 Temperature Reading
#include <Wire.h>
#define DS3231_ADDR 0x68
#define TEMP_MSB 0x11
#define TEMP_LSB 0x12
float ds3231_get_temperature() {
Wire.beginTransmission(DS3231_ADDR);
Wire.write(TEMP_MSB);
Wire.endTransmission();
Wire.requestFrom(DS3231_ADDR, 2);
uint8_t msb = Wire.read();
uint8_t lsb = Wire.read();
// Combine MSB and LSB
int16_t temp = (msb << 2) | (lsb >> 6);
// Handle negative temperatures
if (temp & 0x200) {
temp |= 0xFC00;
}
return temp * 0.25;
}
void setup() {
Serial.begin(9600);
Wire.begin();
}
void loop() {
float temperature = ds3231_get_temperature();
Serial.print("Temperature: ");
Serial.print(temperature);
Serial.println("°C");
delay(1000);
}
PCF8523
Features
- Accuracy: ±3 minutes per year
- Operating Voltage: 1.8-5.5V
- Battery Backup: CR2032
- Interface: I2C (100-400 kHz)
- Address: 0x68 (fixed)
- Alarm: Single programmable alarm
- Timer: Countdown timer
- Low Power: Multiple power-saving modes
Arduino PCF8523
#include <Wire.h>
#include <RTClib.h>
RTC_PCF8523 rtc;
void setup() {
Serial.begin(9600);
Wire.begin();
if (!rtc.begin()) {
Serial.println("Couldn't find RTC");
while (1);
}
if (!rtc.initialized() || rtc.lostPower()) {
Serial.println("RTC is NOT initialized, setting time...");
rtc.adjust(DateTime(F(__DATE__), F(__TIME__)));
}
// Start RTC
rtc.start();
}
void loop() {
DateTime now = rtc.now();
Serial.print(now.year(), DEC);
Serial.print('/');
Serial.print(now.month(), DEC);
Serial.print('/');
Serial.print(now.day(), DEC);
Serial.print(" ");
Serial.print(now.hour(), DEC);
Serial.print(':');
Serial.print(now.minute(), DEC);
Serial.print(':');
Serial.println(now.second(), DEC);
delay(1000);
}
Arduino Examples
Data Logger with RTC
#include <Wire.h>
#include <RTClib.h>
#include <SD.h>
RTC_DS3231 rtc;
const int CS_PIN = 10;
void setup() {
Serial.begin(9600);
Wire.begin();
if (!rtc.begin()) {
Serial.println("RTC error");
while (1);
}
if (!SD.begin(CS_PIN)) {
Serial.println("SD card error");
while (1);
}
}
void loop() {
DateTime now = rtc.now();
float temp = rtc.getTemperature();
// Create filename
char filename[13];
sprintf(filename, "%04d%02d%02d.txt",
now.year(), now.month(), now.day());
// Open file
File dataFile = SD.open(filename, FILE_WRITE);
if (dataFile) {
// Write timestamp and data
dataFile.print(now.hour());
dataFile.print(":");
dataFile.print(now.minute());
dataFile.print(":");
dataFile.print(now.second());
dataFile.print(",");
dataFile.println(temp);
dataFile.close();
Serial.println("Data logged");
} else {
Serial.println("Error opening file");
}
delay(60000); // Log every minute
}
Digital Clock Display
#include <Wire.h>
#include <RTClib.h>
#include <LiquidCrystal.h>
RTC_DS3231 rtc;
LiquidCrystal lcd(12, 11, 5, 4, 3, 2);
void setup() {
Wire.begin();
rtc.begin();
lcd.begin(16, 2);
if (rtc.lostPower()) {
rtc.adjust(DateTime(F(__DATE__), F(__TIME__)));
}
}
void loop() {
DateTime now = rtc.now();
// Display date on line 1
lcd.setCursor(0, 0);
lcd.print(now.day(), DEC);
lcd.print('/');
lcd.print(now.month(), DEC);
lcd.print('/');
lcd.print(now.year(), DEC);
lcd.print(" ");
// Display time on line 2
lcd.setCursor(0, 1);
if (now.hour() < 10) lcd.print('0');
lcd.print(now.hour(), DEC);
lcd.print(':');
if (now.minute() < 10) lcd.print('0');
lcd.print(now.minute(), DEC);
lcd.print(':');
if (now.second() < 10) lcd.print('0');
lcd.print(now.second(), DEC);
delay(1000);
}
Alarm Clock
#include <Wire.h>
#include <RTClib.h>
RTC_DS3231 rtc;
const int BUZZER_PIN = 9;
const int BUTTON_PIN = 2;
uint8_t alarm_hour = 7;
uint8_t alarm_minute = 30;
bool alarm_active = false;
void setup() {
Serial.begin(9600);
Wire.begin();
rtc.begin();
pinMode(BUZZER_PIN, OUTPUT);
pinMode(BUTTON_PIN, INPUT_PULLUP);
// Clear stale flags and free the SQW pin for alarm use, then set the alarm
rtc.clearAlarm(1);
rtc.writeSqwPinMode(DS3231_OFF);
rtc.setAlarm1(DateTime(0, 0, 0, alarm_hour, alarm_minute, 0),
              DS3231_A1_Hour);
}
void loop() {
DateTime now = rtc.now();
// Check alarm
if (rtc.alarmFired(1)) {
alarm_active = true;
rtc.clearAlarm(1);
}
// Sound buzzer if alarm active
if (alarm_active) {
tone(BUZZER_PIN, 1000, 500);
delay(1000);
// Check for button press to stop
if (digitalRead(BUTTON_PIN) == LOW) {
alarm_active = false;
noTone(BUZZER_PIN);
}
}
// Display time
Serial.print(now.hour());
Serial.print(":");
Serial.print(now.minute());
Serial.print(":");
Serial.println(now.second());
delay(1000);
}
STM32 Examples
DS3231 with STM32 HAL
#include "main.h"
#include <stdio.h>
I2C_HandleTypeDef hi2c1;
#define DS3231_ADDR (0x68 << 1)
uint8_t dec_to_bcd(uint8_t val) {
return ((val / 10) << 4) | (val % 10);
}
uint8_t bcd_to_dec(uint8_t val) {
return ((val >> 4) * 10) + (val & 0x0F);
}
void ds3231_set_time(uint8_t hour, uint8_t min, uint8_t sec) {
uint8_t data[4];
data[0] = 0x00; // Start register
data[1] = dec_to_bcd(sec);
data[2] = dec_to_bcd(min);
data[3] = dec_to_bcd(hour);
HAL_I2C_Master_Transmit(&hi2c1, DS3231_ADDR, data, 4, HAL_MAX_DELAY);
}
void ds3231_read_time(uint8_t *hour, uint8_t *min, uint8_t *sec) {
uint8_t reg = 0x00;
uint8_t data[3];
HAL_I2C_Master_Transmit(&hi2c1, DS3231_ADDR, &reg, 1, HAL_MAX_DELAY);
HAL_I2C_Master_Receive(&hi2c1, DS3231_ADDR, data, 3, HAL_MAX_DELAY);
*sec = bcd_to_dec(data[0]);
*min = bcd_to_dec(data[1]);
*hour = bcd_to_dec(data[2]);
}
float ds3231_get_temperature(void) {
uint8_t reg = 0x11;
uint8_t data[2];
HAL_I2C_Master_Transmit(&hi2c1, DS3231_ADDR, &reg, 1, HAL_MAX_DELAY);
HAL_I2C_Master_Receive(&hi2c1, DS3231_ADDR, data, 2, HAL_MAX_DELAY);
int16_t temp = (data[0] << 2) | (data[1] >> 6);
if (temp & 0x200) {
temp |= 0xFC00;
}
return temp * 0.25;
}
int main(void) {
HAL_Init();
SystemClock_Config();
MX_I2C1_Init();
MX_USART1_UART_Init();
// Set initial time
ds3231_set_time(12, 30, 0);
while (1) {
uint8_t hour, min, sec;
ds3231_read_time(&hour, &min, &sec);
float temp = ds3231_get_temperature();
printf("%02d:%02d:%02d Temp: %.2f°C\r\n", hour, min, sec, temp);
HAL_Delay(1000);
}
}
AVR Bare Metal
DS1307 with AVR (ATmega328P)
#include <avr/io.h>
#include <util/delay.h>
#include <stdio.h>
#define DS1307_ADDR 0x68
#define F_SCL 100000UL
#define TWI_BITRATE (((F_CPU / F_SCL) - 16) / 2)
uint8_t dec_to_bcd(uint8_t val) {
return ((val / 10) << 4) | (val % 10);
}
uint8_t bcd_to_dec(uint8_t val) {
return ((val >> 4) * 10) + (val & 0x0F);
}
void i2c_init(void) {
TWBR = (uint8_t)TWI_BITRATE;
TWCR = (1 << TWEN);
}
void i2c_start(void) {
TWCR = (1 << TWINT) | (1 << TWSTA) | (1 << TWEN);
while (!(TWCR & (1 << TWINT)));
}
void i2c_stop(void) {
TWCR = (1 << TWINT) | (1 << TWSTO) | (1 << TWEN);
}
void i2c_write(uint8_t data) {
TWDR = data;
TWCR = (1 << TWINT) | (1 << TWEN);
while (!(TWCR & (1 << TWINT)));
}
uint8_t i2c_read_ack(void) {
TWCR = (1 << TWINT) | (1 << TWEN) | (1 << TWEA);
while (!(TWCR & (1 << TWINT)));
return TWDR;
}
uint8_t i2c_read_nack(void) {
TWCR = (1 << TWINT) | (1 << TWEN);
while (!(TWCR & (1 << TWINT)));
return TWDR;
}
void ds1307_set_time(uint8_t hour, uint8_t min, uint8_t sec) {
i2c_start();
i2c_write((DS1307_ADDR << 1) | 0);
i2c_write(0x00); // Start register
i2c_write(dec_to_bcd(sec));
i2c_write(dec_to_bcd(min));
i2c_write(dec_to_bcd(hour));
i2c_stop();
}
void ds1307_read_time(uint8_t *hour, uint8_t *min, uint8_t *sec) {
i2c_start();
i2c_write((DS1307_ADDR << 1) | 0);
i2c_write(0x00); // Start register
i2c_start(); // Repeated start
i2c_write((DS1307_ADDR << 1) | 1);
*sec = bcd_to_dec(i2c_read_ack());
*min = bcd_to_dec(i2c_read_ack());
*hour = bcd_to_dec(i2c_read_nack());
i2c_stop();
}
int main(void) {
i2c_init();
uart_init(); // UART/printf setup assumed to be provided elsewhere
// Set time to 12:30:00
ds1307_set_time(12, 30, 0);
while (1) {
uint8_t hour, min, sec;
ds1307_read_time(&hour, &min, &sec);
printf("%02d:%02d:%02d\n", hour, min, sec);
_delay_ms(1000);
}
return 0;
}
Best Practices
- Battery Backup: Always install backup battery for continuous operation
- Pull-up Resistors: Ensure 4.7kΩ pull-ups on SDA and SCL
- Power Supply: DS1307 requires 5V, DS3231 works with 3.3V-5V
- Initial Setup: Set time after first power-on or battery change
- Lost Power Check: Check and handle RTC power loss
- BCD Format: Remember to convert between decimal and BCD
- I2C Speed: Use 100 kHz for reliability, 400 kHz if needed
Troubleshooting
Common Issues
RTC Not Responding:
- Check I2C address (usually 0x68)
- Verify SDA/SCL connections
- Ensure pull-up resistors present
- Check power supply voltage
Clock Not Keeping Time:
- Install backup battery (CR2032)
- Check battery voltage (should be ~3V)
- For DS1307: Clear CH (Clock Halt) bit
- Verify crystal oscillator is working
Inaccurate Time:
- DS1307: Normal (±2 min/month), consider DS3231
- DS3231: Check temperature effects
- Calibrate using aging offset register (DS3231)
I2C Communication Errors:
// Probe the expected address and check that the device ACKs
Wire.beginTransmission(0x68);
if (Wire.endTransmission() == 0) {
Serial.println("RTC found at 0x68");
} else {
Serial.println("RTC not found");
}
Resources
- DS1307 Datasheet: Maxim Integrated
- DS3231 Datasheet: Maxim Integrated
- RTClib Library: https://github.com/adafruit/RTClib
- I2C Protocol: See I2C documentation
See Also
GPIO
General Purpose Input/Output (GPIO)
GPIO stands for General Purpose Input/Output. It is a generic pin on an integrated circuit or computer board whose behavior (including whether it is an input or output pin) can be controlled by the user at runtime. GPIO pins are a staple in embedded systems and microcontroller projects due to their versatility and ease of use.
Key Features of GPIO
- Configurable Direction: Each GPIO pin can be configured as either an input or an output. This allows the pin to either read signals from external devices (input) or send signals to external devices (output).
- Digital Signals: GPIO pins typically handle digital signals, meaning they can be in one of two states: high (1) or low (0). The voltage levels corresponding to these states depend on the specific hardware but are commonly 3.3V or 5V for high and 0V for low.
- Interrupts: Many GPIO pins support interrupts, which allow the pin to trigger an event in the software when a specific condition is met, such as a change in state. This is useful for responding to external events without constantly polling the pin.
- Pull-up/Pull-down Resistors: GPIO pins often have configurable pull-up or pull-down resistors. These resistors ensure that the pin is in a known state (high or low) when it is not actively being driven by an external source.
- Debouncing: When reading input from mechanical switches, GPIO pins can experience noise or "bouncing." Debouncing techniques, either in hardware or software, are used to ensure that the signal is stable and accurate.
Common Uses of GPIO
- LED Control: Turning LEDs on and off or controlling their brightness using Pulse Width Modulation (PWM).
- Button Inputs: Reading the state of buttons or switches to trigger actions in the software.
- Sensor Interfacing: Reading data from various sensors like temperature, humidity, or motion sensors.
- Communication: Implementing simple communication protocols like I2C, SPI, or UART using GPIO pins.
Example Code
Here is an example of how to configure and use a GPIO pin in a typical microcontroller environment (e.g., using the Arduino platform):
// Define the pin number
const int ledPin = 13; // Pin number for the LED
void setup() {
// Initialize the digital pin as an output.
pinMode(ledPin, OUTPUT);
}
void loop() {
// Turn the LED on (HIGH is the voltage level)
digitalWrite(ledPin, HIGH);
// Wait for a second
delay(1000);
// Turn the LED off by making the voltage LOW
digitalWrite(ledPin, LOW);
// Wait for a second
delay(1000);
}
Interrupts
Overview
Interrupts are signals that temporarily halt the normal execution of a program or process, allowing the system to respond to important events. They are a crucial mechanism in computer architecture, enabling efficient multitasking and real-time processing.
Types of Interrupts
- Hardware Interrupts: Generated by hardware devices (e.g., keyboard, mouse, network cards) to signal that they require attention from the CPU. These interrupts can occur at any time and are typically prioritized to ensure that critical tasks are handled promptly.
- Software Interrupts: Triggered by software instructions, such as system calls or exceptions. These interrupts allow programs to request services from the operating system or handle errors gracefully.
- Timer Interrupts: Generated by a timer within the system to allow the operating system to perform regular tasks, such as scheduling processes and managing system resources.
Interrupt Handling
When an interrupt occurs, the CPU stops executing the current program and saves its state. The system then executes an interrupt handler, a special routine designed to address the specific interrupt. After the handler completes its task, the CPU restores the saved state and resumes the interrupted program.
Applications of Interrupts
- Real-Time Systems: Interrupts are essential in real-time systems where timely responses to events are critical, such as in embedded systems, automotive applications, and industrial automation.
- Multitasking: Operating systems use interrupts to manage multiple processes efficiently, allowing them to share CPU time and resources without significant delays.
- Event-Driven Programming: In event-driven architectures, interrupts facilitate the handling of user inputs and other events, enabling responsive applications.
Conclusion
Understanding interrupts is vital for developers working with low-level programming, operating systems, and embedded systems. They play a key role in ensuring that systems can respond quickly and efficiently to a variety of events.
Timers and Counters
Overview
Timers and counters are essential hardware peripherals in microcontrollers that keep track of time, count events, generate precise delays, create PWM signals, and trigger interrupts at specific intervals. Unlike software delays (which block the CPU), hardware timers run independently, allowing your program to multitask efficiently.
Key Concepts
Timer vs Counter
| Feature | Timer | Counter |
|---|---|---|
| Clock Source | Internal (system clock) | External (GPIO pin) |
| Purpose | Measure time intervals | Count external events |
| Speed | Fixed by clock | Variable (event-driven) |
| Example Use | Generate 1ms interrupts | Count encoder pulses |
Timer Components
- Counter Register: Stores current count value
- Prescaler: Divides input clock to slow down counting
- Compare Register: Value to trigger events when matched
- Auto-reload Register: Value to reset counter to (for periodic timers)
Timer Modes
- Basic Timer: Simple counting up or down
- PWM Mode: Generate pulse-width modulated signals
- Input Capture: Measure external signal timing
- Output Compare: Trigger events at specific times
- Encoder Mode: Read quadrature encoders
How It Works
Clock and Prescaler
System Clock (16 MHz) -> Prescaler (/256) -> Timer Clock (62.5 kHz) -> Counter increments at 62.5 kHz
Formula:
Timer Frequency = CPU Frequency / Prescaler
Timer Period = 1 / Timer Frequency
Overflow Time = (2^bits / Timer Frequency)
Example (Arduino Uno - 16 MHz):
Prescaler = 256
Timer Frequency = 16,000,000 / 256 = 62,500 Hz
Timer Period = 1 / 62,500 = 16 us per tick
For 8-bit timer (0-255):
Overflow Time = 256 * 16 us = 4.096 ms
For 16-bit timer (0-65535):
Overflow Time = 65,536 * 16 us = 1.048576 seconds (~1.05 s)
Code Examples
Arduino Timer Interrupt
// Using Timer1 (16-bit) for 1ms interrupt
volatile unsigned long millisCounter = 0;
void setup() {
Serial.begin(9600);
// Stop interrupts during setup
cli();
// Reset Timer1
TCCR1A = 0;
TCCR1B = 0;
TCNT1 = 0;
// Set compare match register for 1ms
// OCR1A = (16MHz / (prescaler * desired frequency)) - 1
// OCR1A = (16,000,000 / (64 * 1000)) - 1 = 249
OCR1A = 249;
// Turn on CTC mode (Clear Timer on Compare Match)
TCCR1B |= (1 << WGM12);
// Set CS11 and CS10 bits for 64 prescaler
TCCR1B |= (1 << CS11) | (1 << CS10);
// Enable timer compare interrupt
TIMSK1 |= (1 << OCIE1A);
// Enable global interrupts
sei();
}
// Timer1 interrupt service routine (ISR)
ISR(TIMER1_COMPA_vect) {
millisCounter++;
// Your code here - keep it SHORT!
// DO NOT use Serial.print() in ISR
}
void loop() {
// Use millisCounter instead of millis()
static unsigned long lastPrint = 0;
if (millisCounter - lastPrint >= 1000) {
lastPrint = millisCounter;
Serial.println(millisCounter);
}
}
ESP32 Hardware Timer
// ESP32 has 4 hardware timers (0-3)
hw_timer_t *timer = NULL;
volatile uint32_t timerCounter = 0;
void IRAM_ATTR onTimer() {
timerCounter++;
// Keep ISR short and fast!
}
void setup() {
Serial.begin(115200);
// Initialize timer (timer number, prescaler, count up)
// Note: this is the arduino-esp32 core 2.x API; core 3.x changed timerBegin()
// ESP32 APB timer clock is 80 MHz
// Prescaler of 80 gives 1 MHz (1 tick = 1 us)
timer = timerBegin(0, 80, true);
// Attach interrupt function
timerAttachInterrupt(timer, &onTimer, true);
// Set alarm to trigger every 1ms (1000 us)
timerAlarmWrite(timer, 1000, true); // true = auto-reload
// Enable timer alarm
timerAlarmEnable(timer);
Serial.println("Timer initialized!");
}
void loop() {
static uint32_t lastCount = 0;
if (timerCounter - lastCount >= 1000) {
lastCount = timerCounter;
Serial.print("Timer count: ");
Serial.println(timerCounter);
}
}
// Ticker library (easier alternative)
#include <Ticker.h>
Ticker ticker;
volatile int count = 0;
void timerCallback() {
count++;
}
void setup() {
// Call timerCallback every 0.001 seconds (1ms)
ticker.attach(0.001, timerCallback);
}
STM32 HAL Timer
#include "stm32f4xx_hal.h"
TIM_HandleTypeDef htim2;
volatile uint32_t timerTicks = 0;
void Timer_Init(void) {
TIM_ClockConfigTypeDef sClockSourceConfig = {0};
TIM_MasterConfigTypeDef sMasterConfig = {0};
// TIM2 configuration
// APB1 clock = 84 MHz (for STM32F4)
// Prescaler = 8400 - 1 -> 10 kHz timer clock
// Period = 10 - 1 -> 1 kHz interrupt (1ms)
htim2.Instance = TIM2;
htim2.Init.Prescaler = 8400 - 1; // 84 MHz / 8400 = 10 kHz
htim2.Init.CounterMode = TIM_COUNTERMODE_UP;
htim2.Init.Period = 10 - 1; // 10 kHz / 10 = 1 kHz (1ms)
htim2.Init.ClockDivision = TIM_CLOCKDIVISION_DIV1;
htim2.Init.AutoReloadPreload = TIM_AUTORELOAD_PRELOAD_DISABLE;
HAL_TIM_Base_Init(&htim2);
sClockSourceConfig.ClockSource = TIM_CLOCKSOURCE_INTERNAL;
HAL_TIM_ConfigClockSource(&htim2, &sClockSourceConfig);
// Enable timer interrupt
HAL_TIM_Base_Start_IT(&htim2);
}
// Timer interrupt callback
void HAL_TIM_PeriodElapsedCallback(TIM_HandleTypeDef *htim) {
if (htim->Instance == TIM2) {
timerTicks++;
// Your periodic code here
}
}
// In main.c, enable interrupt in NVIC
void MX_TIM2_Init(void) {
Timer_Init();
HAL_NVIC_SetPriority(TIM2_IRQn, 0, 0);
HAL_NVIC_EnableIRQ(TIM2_IRQn);
}
PWM Generation with Timers
// Arduino PWM using Timer1
void setup() {
// Set pins as output
pinMode(9, OUTPUT); // OC1A
pinMode(10, OUTPUT); // OC1B
// Stop timer during configuration
TCCR1A = 0;
TCCR1B = 0;
// Fast PWM mode, ICR1 as TOP
// WGM13:0 = 14 (Fast PWM, TOP = ICR1)
TCCR1A = (1 << WGM11);
TCCR1B = (1 << WGM13) | (1 << WGM12);
// Non-inverting mode for both channels
TCCR1A |= (1 << COM1A1) | (1 << COM1B1);
// Prescaler = 8
TCCR1B |= (1 << CS11);
// Set TOP value for desired frequency
// PWM Frequency = F_CPU / (Prescaler * (1 + TOP))
// For 50 Hz: TOP = 16,000,000 / (8 * 50) - 1 = 39999
ICR1 = 39999; // 50 Hz
// Set duty cycle
OCR1A = 3000; // ~7.5% duty cycle on pin 9
OCR1B = 6000; // ~15% duty cycle on pin 10
}
// Servo control example
void setServoAngle(uint8_t angle) {
// Servo expects 1ms-2ms pulse every 20ms (50 Hz)
// 1ms = 0 degrees = 2000 counts
// 1.5ms = 90 degrees = 3000 counts
// 2ms = 180 degrees = 4000 counts
uint16_t pulse = map(angle, 0, 180, 2000, 4000);
OCR1A = pulse;
}
void loop() {
setServoAngle(0);
delay(1000);
setServoAngle(90);
delay(1000);
setServoAngle(180);
delay(1000);
}
Input Capture Mode
// Measure frequency of external signal on pin 8 (ICP1)
volatile unsigned long captureTime1 = 0;
volatile unsigned long captureTime2 = 0;
volatile boolean newCapture = false;
void setup() {
Serial.begin(9600);
// Configure Timer1 for input capture
TCCR1A = 0;
TCCR1B = 0;
// Prescaler = 64 (250 kHz timer, 4us resolution)
TCCR1B |= (1 << CS11) | (1 << CS10);
// Input Capture on rising edge
TCCR1B |= (1 << ICES1);
// Enable input capture interrupt
TIMSK1 |= (1 << ICIE1);
// Enable global interrupts
sei();
}
// Input capture interrupt
ISR(TIMER1_CAPT_vect) {
static boolean firstCapture = true;
if (firstCapture) {
captureTime1 = ICR1;
firstCapture = false;
} else {
captureTime2 = ICR1;
newCapture = true;
firstCapture = true;
}
}
void loop() {
if (newCapture) {
newCapture = false;
// Calculate period
unsigned long period = captureTime2 - captureTime1;
// Calculate frequency
// Timer runs at 250 kHz (4us per tick)
float frequency = 250000.0 / period;
Serial.print("Frequency: ");
Serial.print(frequency);
Serial.println(" Hz");
}
}
Common Applications
1. Precise Timing Without Delay
unsigned long previousMillis = 0;
const long interval = 1000;
void loop() {
unsigned long currentMillis = millis();
if (currentMillis - previousMillis >= interval) {
previousMillis = currentMillis;
// Execute every 1 second without blocking
toggleLED();
}
// Other code runs continuously
checkSensors();
processData();
}
2. Multiple Periodic Tasks
volatile uint32_t timerTicks = 0;
ISR(TIMER1_COMPA_vect) {
timerTicks++;
}
void loop() {
static uint32_t lastTask1 = 0;
static uint32_t lastTask2 = 0;
static uint32_t lastTask3 = 0;
// Task 1: Every 10ms
if (timerTicks - lastTask1 >= 10) {
lastTask1 = timerTicks;
readSensors();
}
// Task 2: Every 100ms
if (timerTicks - lastTask2 >= 100) {
lastTask2 = timerTicks;
updateDisplay();
}
// Task 3: Every 1000ms
if (timerTicks - lastTask3 >= 1000) {
lastTask3 = timerTicks;
sendData();
}
}
3. Watchdog Timer
#include <avr/wdt.h>
void setup() {
// Enable watchdog timer (8 second timeout)
wdt_enable(WDTO_8S);
}
void loop() {
// Do work
processData();
// Reset watchdog (prevent system reset)
wdt_reset();
// If code hangs, watchdog resets system after 8 seconds
}
4. Real-Time Clock (RTC)
// Using timer to maintain time
volatile uint32_t seconds = 0;
volatile uint16_t milliseconds = 0;
ISR(TIMER1_COMPA_vect) {
milliseconds++;
if (milliseconds >= 1000) {
milliseconds = 0;
seconds++;
}
}
void getTime(uint8_t *hours, uint8_t *minutes, uint8_t *secs) {
noInterrupts();
uint32_t totalSeconds = seconds;
interrupts();
*hours = (totalSeconds / 3600) % 24;
*minutes = (totalSeconds / 60) % 60;
*secs = totalSeconds % 60;
}
5. Debouncing Buttons
volatile uint32_t timerMs = 0;
ISR(TIMER1_COMPA_vect) {
timerMs++;
}
const int buttonPin = 2;
const int debounceTime = 50; // 50ms
bool readButtonDebounced() {
static uint32_t lastDebounceTime = 0;
static bool lastButtonState = HIGH;
static bool buttonState = HIGH;
bool reading = digitalRead(buttonPin);
if (reading != lastButtonState) {
lastDebounceTime = timerMs;
}
if ((timerMs - lastDebounceTime) > debounceTime) {
if (reading != buttonState) {
buttonState = reading;
return (buttonState == LOW); // Return true on button press
}
}
lastButtonState = reading;
return false;
}
Timer Prescaler Values
AVR (Arduino Uno/Nano/Mega)
| Prescaler | CS12 | CS11 | CS10 | Timer Frequency (16 MHz) |
|---|---|---|---|---|
| None | 0 | 0 | 0 | Stopped |
| 1 | 0 | 0 | 1 | 16 MHz |
| 8 | 0 | 1 | 0 | 2 MHz |
| 64 | 0 | 1 | 1 | 250 kHz |
| 256 | 1 | 0 | 0 | 62.5 kHz |
| 1024 | 1 | 0 | 1 | 15.625 kHz |
Best Practices
1. Keep ISRs Short and Fast
// BAD - Don't do this in ISR!
ISR(TIMER1_COMPA_vect) {
Serial.println("Timer fired"); // Serial is slow!
delay(100); // Blocks other interrupts!
float result = complexCalculation(); // Takes too long!
}
// GOOD - Set flags, process in main loop
volatile bool timerFlag = false;
ISR(TIMER1_COMPA_vect) {
timerFlag = true; // Just set a flag
}
void loop() {
if (timerFlag) {
timerFlag = false;
Serial.println("Timer fired"); // Do slow stuff here
processData();
}
}
2. Protect Shared Variables
volatile uint32_t sharedCounter = 0;
ISR(TIMER1_COMPA_vect) {
sharedCounter++;
}
void loop() {
// BAD - Not atomic! Can be corrupted if an interrupt occurs mid-read:
// uint32_t localCopy = sharedCounter;
// GOOD - Disable interrupts during the multi-byte read
noInterrupts();
uint32_t localCopy = sharedCounter;
interrupts();
Serial.println(localCopy);
}
3. Calculate Timer Values Correctly
// Formula for CTC mode:
// Compare Value = (F_CPU / (Prescaler * Desired_Frequency)) - 1
#define F_CPU 16000000UL
#define PRESCALER 64
#define DESIRED_HZ 1000 // 1 kHz
uint16_t compareValue = (F_CPU / (PRESCALER * DESIRED_HZ)) - 1;
// compareValue = (16000000 / (64 * 1000)) - 1 = 249
OCR1A = compareValue;
Common Issues and Debugging
Problem: Timer Interrupt Not Firing
Check:
- Global interrupts enabled (sei())
- Specific timer interrupt enabled
- Prescaler and compare values calculated correctly
- Clock source selected
- ISR function name matches vector name
Problem: Inaccurate Timing
Causes:
- Wrong prescaler calculation
- Integer overflow in calculations
- CPU frequency mismatch
- Crystal tolerance
Problem: System Becomes Unresponsive
Causes:
- ISR takes too long (blocks other code)
- Interrupt firing too frequently
- Infinite loop in ISR
- Nested interrupts causing stack overflow
ELI10 (Explain Like I'm 10)
Imagine you have a special alarm clock that can do cool tricks:
- Basic Timer: Counts from 0 to 100, then starts over. Like counting seconds!
- Prescaler: Instead of counting every second, you count every 10 seconds. It's like skipping numbers to count slower.
- Compare Match: When the count reaches a special number (like 50), the alarm rings! Then it keeps counting.
- PWM: The alarm flashes a light on and off really fast. By changing how long it stays on vs off, you can make the light look dimmer or brighter!
- Input Capture: You press a button, and the timer remembers what number it was at. Press again, and you can figure out how long between presses!
The coolest part? The timer runs by itself in the background - you don't have to watch it! It's like having a helper that tells you when it's time to do something, while you focus on other tasks.
Further Resources
- Arduino Timer Interrupts
- AVR Timers Tutorial
- ESP32 Timer Documentation
- STM32 Timer Cookbook
- Secrets of Arduino PWM
Watchdog Timers
A Watchdog Timer (WDT) is a hardware or software timer that is used to detect and recover from computer malfunctions. During normal operation, the system regularly resets the watchdog timer to prevent it from elapsing, or "timing out." If the system fails to reset the watchdog timer, it is assumed to be malfunctioning, and corrective actions are taken, such as resetting the system.
Key Concepts
- Timeout Period: The duration for which the watchdog timer runs before it times out. If the timer is not reset within this period, it triggers a system reset or other corrective actions.
- Reset Mechanism: The action taken when the watchdog timer times out. This is typically a system reset, but it can also include other actions like logging an error or entering a safe state.
- Feeding the Watchdog: The process of regularly resetting the watchdog timer to prevent it from timing out. This is also known as "kicking" or "patting" the watchdog.
Example Usage
- Embedded Systems: Watchdog timers are commonly used in embedded systems to ensure that the system can recover from unexpected failures. For example, if a microcontroller stops responding, the watchdog timer can reset it to restore normal operation.
- Safety-Critical Applications: In applications where safety is paramount, such as automotive or medical devices, watchdog timers help ensure that the system can recover from faults and continue to operate safely.
Conclusion
Watchdog timers are essential components in many systems, providing a mechanism to detect and recover from malfunctions. Understanding how to configure and use watchdog timers is crucial for developing reliable and resilient systems.
Power Management
Power management refers to the process of managing the power consumption of a device or system to optimize energy efficiency and prolong battery life. It is crucial in various applications, especially in portable devices like smartphones, laptops, and IoT devices.
Key Concepts
- Sleep Modes: Many devices have different sleep modes that reduce power consumption when the device is not in active use. These modes can range from low-power states to complete shutdowns.
- Dynamic Voltage and Frequency Scaling (DVFS): This technique adjusts the voltage and frequency of a processor based on the workload, allowing for reduced power consumption during low-demand periods.
- Power Gating: This method involves shutting off power to certain components of a device when they are not in use, further conserving energy.
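DVFS works because dynamic CMOS power scales roughly as P ≈ C·V²·f. A quick back-of-the-envelope sketch (the capacitance and voltage figures below are made-up, illustrative values):

```python
def dynamic_power(c_eff, voltage, freq_hz):
    # Approximate switching power only: P = C_eff * V^2 * f
    # (ignores leakage and short-circuit currents).
    return c_eff * voltage ** 2 * freq_hz

full_speed = dynamic_power(1e-9, 1.2, 1.0e9)   # 1.2 V at 1 GHz
scaled_down = dynamic_power(1e-9, 0.9, 0.5e9)  # 0.9 V at 500 MHz
```

Halving the frequency and lowering the voltage together cuts power much more than linearly, which is why DVFS is so effective during low-demand periods.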
Applications
Power management techniques are widely used in:
- Mobile Devices: Extending battery life through efficient power usage.
- Data Centers: Reducing energy costs and improving cooling efficiency.
- Embedded Systems: Ensuring long operational life in battery-powered applications.
Conclusion
Effective power management is essential for enhancing the performance and longevity of electronic devices. By implementing various techniques, developers can create more energy-efficient systems that meet the demands of modern applications.
Debugging
Logic Analyzer
Saleae Logic 8
Networking
Comprehensive networking reference covering protocols, models, and networking fundamentals.
Networking Models
OSI Model
The 7-layer conceptual framework for network communication:
- Layer 7: Application
- Layer 6: Presentation
- Layer 5: Session
- Layer 4: Transport
- Layer 3: Network
- Layer 2: Data Link
- Layer 1: Physical
TCP/IP Model
The practical 4-layer model used in modern networks:
- Application Layer
- Transport Layer
- Internet Layer
- Network Access Layer
Core Protocols
IPv4 (Internet Protocol version 4)
- 32-bit addressing and packet format
- Address classes and private IP ranges
- Subnetting and CIDR notation
- Routing and fragmentation
- NAT (Network Address Translation)
- ICMP diagnostics and tools
IPv6 (Internet Protocol version 6)
- 128-bit addressing and packet format
- Address types (unicast, multicast, anycast)
- SLAAC and auto-configuration
- Neighbor Discovery Protocol (NDP)
- Extension headers
- ICMPv6 and transition mechanisms
TCP (Transmission Control Protocol)
- Reliable, connection-oriented communication
- 3-way handshake
- Flow control and congestion control
- Sequence numbers and acknowledgments
- Connection termination
UDP (User Datagram Protocol)
- Fast, connectionless communication
- Low overhead (8-byte header)
- No reliability guarantees
- Use cases: DNS, streaming, gaming, VoIP
- Socket programming examples
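A minimal UDP round trip over loopback with Python's stdlib socket module illustrates the connectionless model (no handshake, just datagrams; ports are chosen by the OS):

```python
import socket

# Server side: bind to loopback, let the OS pick a free port.
server = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
server.bind(("127.0.0.1", 0))
server.settimeout(2)
addr = server.getsockname()

# Client side: no connect/handshake needed, just send a datagram.
client = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
client.settimeout(2)
client.sendto(b"ping", addr)

data, peer = server.recvfrom(1024)   # each recvfrom returns one whole datagram
server.sendto(b"pong", peer)
reply, _ = client.recvfrom(1024)

client.close()
server.close()
```

There is no delivery guarantee: if either datagram were dropped, `recvfrom` would simply time out rather than retransmit.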
HTTP/HTTPS
- Web communication protocol
- Request methods (GET, POST, PUT, DELETE)
- Status codes
- Headers and caching
- Authentication and CORS
- REST API design
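HTTP/1.1 is a plain-text protocol, so a request is just CRLF-delimited lines. A sketch of building and parsing the request line by hand (real code should use http.client or a library instead):

```python
# Build a minimal GET request by hand.
request = (
    "GET /index.html HTTP/1.1\r\n"
    "Host: example.com\r\n"
    "Connection: close\r\n"
    "\r\n"                      # blank line terminates the header section
).encode("ascii")

# Parse the request line back out: method, path, version.
method, path, version = request.split(b"\r\n")[0].split(b" ")
```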
Name Resolution
DNS (Domain Name System)
- Translates domain names to IP addresses
- DNS hierarchy and record types
- Query and response messages
- DNS caching and TTL
- DNSSEC security
- DNS over HTTPS (DoH) and DNS over TLS (DoT)
- Public DNS servers
mDNS (Multicast DNS)
- Zero-configuration networking
- Local network name resolution (.local domain)
- Service discovery (DNS-SD)
- Avahi and Bonjour implementations
- Use cases: printers, file sharing, IoT devices
NAT Traversal
STUN (Session Traversal Utilities for NAT)
- Discovers public IP address and port
- Detects NAT type
- Enables peer-to-peer connections
- Used in WebRTC and VoIP
- Message format and examples
- Public STUN servers
TURN (Traversal Using Relays around NAT)
- Relays traffic when direct connection fails
- Fallback for restrictive NATs and firewalls
- Bandwidth-intensive
- Used with ICE in WebRTC
- Server setup with coturn
- Cost considerations
ICE (Interactive Connectivity Establishment)
- Framework for establishing peer-to-peer connections
- Combines STUN and TURN for NAT traversal
- Candidate gathering and connectivity checks
- Priority-based path selection
- Handles symmetric NAT and firewalls
- Used by WebRTC and VoIP
PCP (Port Control Protocol)
- Automatic port mapping and firewall control
- Successor to NAT-PMP with IPv6 support
- MAP and PEER opcodes for different use cases
- Works with multiple NATs in path
- Third-party mappings and explicit lifetimes
- Used by modern applications and IoT
NAT-PMP (NAT Port Mapping Protocol)
- Simple automatic port forwarding protocol
- Lightweight UDP-based (12-16 byte packets)
- IPv4 support with time-limited mappings
- Developed by Apple, widely deployed
- Gateway discovery and external IP detection
- Used by BitTorrent, VoIP, and gaming
Real-Time Communication
WebSocket
- Full-duplex bidirectional communication
- Low-latency persistent connections
- WebSocket handshake and frame format
- Client and server implementations
- Use cases: chat, live updates, gaming
- Authentication and security
- Heartbeat and reconnection strategies
WebRTC (Web Real-Time Communication)
- Browser-based peer-to-peer communication
- Video, audio, and data channels
- getUserMedia API and RTCPeerConnection
- Signaling and SDP offer/answer
- Media codecs and quality adaptation
- Security with mandatory encryption
- Simulcast and bandwidth management
Network Discovery
UPnP (Universal Plug and Play)
- Automatic device discovery
- Zero-configuration setup
- SSDP (Simple Service Discovery Protocol)
- Port forwarding (IGD)
- Security considerations
- Common device types
Security
Firewalls
- Packet filtering
- Stateful inspection
- Application layer firewalls
- Next-generation firewalls (NGFW)
- iptables, ufw, firewalld configurations
- NAT and port forwarding
- Firewall architectures (DMZ, screened subnet)
- Security best practices
Quick Reference
Protocol Port Numbers
| Protocol | Port | Transport | Purpose |
|---|---|---|---|
| HTTP | 80 | TCP | Web pages |
| HTTPS | 443 | TCP | Secure web |
| SSH | 22 | TCP | Secure shell |
| FTP | 20/21 | TCP | File transfer |
| DNS | 53 | UDP/TCP | Name resolution |
| DHCP | 67/68 | UDP | IP configuration |
| SMTP | 25 | TCP | Email sending |
| POP3 | 110 | TCP | Email retrieval |
| IMAP | 143 | TCP | Email access |
| STUN | 3478 | UDP | NAT discovery |
| SSDP | 1900 | UDP | UPnP discovery |
| mDNS | 5353 | UDP | Local DNS |
Common Network Tools
# Connectivity Testing
ping <host> # Test reachability
traceroute <host> # Trace route to host
# DNS Lookup
dig <domain> # DNS query
nslookup <domain> # DNS lookup
host <domain> # Simple DNS lookup
# Network Configuration
ifconfig # Network interface config (legacy)
ip addr show # Show IP addresses
ip route show # Show routing table
# Port Scanning
netstat -tuln # Show listening ports
ss -tuln # Socket statistics
nc -zv <host> <port> # Check if port is open
# Packet Capture
tcpdump -i any # Capture all traffic
tcpdump port 80 # Capture HTTP traffic
wireshark # GUI packet analyzer
# Service Discovery
avahi-browse -a # Browse mDNS services
upnpc -l # List UPnP devices
Private IP Address Ranges
10.0.0.0 - 10.255.255.255 (10/8 prefix)
172.16.0.0 - 172.31.255.255 (172.16/12 prefix)
192.168.0.0 - 192.168.255.255 (192.168/16 prefix)
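Python's stdlib ipaddress module can test these ranges directly (note that `is_private` also covers loopback, link-local, and other special-use blocks, not only the RFC 1918 ranges above):

```python
import ipaddress

def is_private(addr: str) -> bool:
    # True for RFC 1918 addresses (and other special-use ranges).
    return ipaddress.ip_address(addr).is_private
```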
Common Subnet Masks
| CIDR | Netmask | Hosts | Typical Use |
|---|---|---|---|
| /8 | 255.0.0.0 | 16,777,214 | Very large networks |
| /16 | 255.255.0.0 | 65,534 | Large networks |
| /24 | 255.255.255.0 | 254 | Small networks |
| /30 | 255.255.255.252 | 2 | Point-to-point links |
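The host counts in the table come straight from the prefix length: a /24 spans 2^(32-24) = 256 addresses, minus the network and broadcast addresses. With ipaddress:

```python
import ipaddress

net = ipaddress.ip_network("192.168.1.0/24")
usable_hosts = net.num_addresses - 2   # subtract network and broadcast addresses
```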
Protocol Relationships
Application Layer:
HTTP, FTP, SMTP, DNS, DHCP, SSH
|
v
Transport Layer:
TCP (reliable) or UDP (fast)
|
v
Network Layer:
IP (routing and addressing)
|
v
Data Link Layer:
Ethernet, WiFi (MAC addresses)
|
v
Physical Layer:
Cables, signals, physical media
Troubleshooting Flow
1. Physical Layer
- Cable connected?
- Link lights on?
-> Use: Visual inspection, ethtool
2. Data Link Layer
- MAC address correct?
- Switch working?
-> Use: arp -a, show mac address-table
3. Network Layer
- IP address assigned?
- Can ping gateway?
- Routing correct?
-> Use: ip addr, ping, traceroute
4. Transport Layer
- Port open?
- Firewall blocking?
- Service running?
-> Use: netstat, telnet, nc
5. Application Layer
- Service configured correctly?
- Authentication working?
- Application logs?
-> Use: curl, application-specific tools
Security Best Practices
Network Segmentation
- Separate networks by function (guest, IoT, corporate)
- Use VLANs for logical separation
- Firewall rules between segments
Access Control
- Implement firewall rules (default deny)
- Use strong authentication
- Enable logging and monitoring
- Regular security audits
Encryption
- Use HTTPS instead of HTTP
- Enable DNS over HTTPS/TLS
- Use VPN for remote access
- Encrypt sensitive traffic
Updates and Patches
- Keep firmware updated
- Patch vulnerabilities promptly
- Disable unused services
- Remove default credentials
Common Scenarios
Home Network Setup
- Router assigns private IPs (192.168.1.x)
- DHCP provides automatic configuration
- NAT translates private to public IP
- DNS resolves domain names (8.8.8.8)
- Devices use mDNS for local discovery
WebRTC Video Call
- STUN discovers public IP addresses
- ICE gathers connection candidates
- Signaling server exchanges candidates
- Direct P2P connection attempted
- TURN relay used if P2P fails
Smart Home Devices
- Devices announce via mDNS (device.local)
- UPnP enables automatic port forwarding
- Devices discover each other (SSDP)
- Control via local network
- Cloud connection for remote access
Further Learning
Books
- TCP/IP Illustrated by W. Richard Stevens
- Computer Networks by Andrew Tanenbaum
- Network Warrior by Gary Donahue
Practice
- Set up home lab with VirtualBox/VMware
- Use Packet Tracer for simulations
- Capture and analyze traffic with Wireshark
- Configure firewall rules
- Set up services (DNS, DHCP, web server)
OSI Model (Open Systems Interconnection)
Overview
The OSI Model is a conceptual framework that standardizes network communication into 7 layers. Each layer has specific responsibilities and communicates with the layers directly above and below it.
The 7 Layers
Layer 7: Application → User applications (HTTP, FTP, SMTP)
Layer 6: Presentation → Data format, encryption (SSL/TLS)
Layer 5: Session → Session management
Layer 4: Transport → End-to-end delivery (TCP, UDP)
Layer 3: Network → Routing, IP addressing
Layer 2: Data Link → MAC addressing, switches
Layer 1: Physical → Physical media, cables, signals
Memory Aids
Top to Bottom: All People Seem To Need Data Processing
Bottom to Top: Please Do Not Throw Sausage Pizza Away
Layer 1: Physical Layer
Purpose
Transmits raw bits (0s and 1s) over physical media.
Responsibilities
- Physical connection between devices
- Bit transmission and reception
- Voltage levels, timing, data rates
- Cable specifications
- Signal encoding
Components
- Cables: Ethernet (Cat5e, Cat6), Fiber optic, Coaxial
- Hubs: Repeat signals to all ports
- Repeaters: Amplify signals
- Network Interface Cards (NICs)
Encoding Examples
Manchester Encoding (Ethernet):
0: High-to-low transition
1: Low-to-high transition
Bits:    1    0    1    1    0
Signal: _/‾  ‾\_  _/‾  _/‾  ‾\_   (one mid-bit transition per bit)
Physical Media Types
| Medium | Speed | Distance | Use Case |
|---|---|---|---|
| Cat5e | 1 Gbps | 100m | Ethernet LAN |
| Cat6 | 10 Gbps | 55m | High-speed LAN |
| Fiber (MM) | 10 Gbps | 550m | Building backbone |
| Fiber (SM) | 100 Gbps | 40km+ | Long distance |
| WiFi | 1-10 Gbps | 100m | Wireless LAN |
Example: Bit Transmission
Computer A wants to send "Hello" (binary: 01001000...)
Physical Layer:
1. Convert bits to electrical signals
2. Transmit over cable at defined voltage levels
High voltage (2.5V) = 1
Low voltage (0V) = 0
3. Receiver samples signals and reconstructs bits
Layer 2: Data Link Layer
Purpose
Provides node-to-node data transfer with error detection.
Responsibilities
- MAC (Media Access Control) addressing
- Frame formatting
- Error detection (CRC)
- Flow control
- Media access control
Sub-layers
- LLC (Logical Link Control): Interface to Network Layer
- MAC (Media Access Control): Access to physical medium
Components
- Switches: Forward frames based on MAC addresses
- Bridges: Connect network segments
- Network Interface Cards: Hardware MAC addresses
Ethernet Frame Format
Preamble | SFD | Dest MAC | Src MAC | Type | Data | FCS
7B | 1B | 6B | 6B | 2B | 46-1500B | 4B
Preamble: 10101010... (synchronization)
SFD: Start Frame Delimiter (10101011)
Dest MAC: Destination hardware address
Src MAC: Source hardware address
Type: Protocol type (0x0800 = IPv4, 0x86DD = IPv6)
Data: Payload (46-1500 bytes)
FCS: Frame Check Sequence (CRC-32)
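The FCS uses the standard CRC-32 polynomial (0x04C11DB7), the same one exposed by Python's zlib.crc32 (the on-wire bit ordering and complementing details differ, but the check value is the classic one):

```python
import zlib

# CRC-32 with the Ethernet polynomial, reflected form.
payload = b"123456789"
fcs = zlib.crc32(payload)
```

0xCBF43926 is the well-known CRC-32 check value for the ASCII string "123456789", a common sanity check for CRC implementations.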
MAC Address Format
AA:BB:CC:DD:EE:FF (48 bits / 6 bytes)
AA:BB:CC - OUI (Organizationally Unique Identifier)
Vendor identification
DD:EE:FF - Device identifier
Example: 00:1A:2B:3C:4D:5E
Example: Frame Forwarding
Switch MAC Address Table:
Port 1: AA:AA:AA:AA:AA:AA
Port 2: BB:BB:BB:BB:BB:BB
Port 3: CC:CC:CC:CC:CC:CC
Frame arrives on Port 1:
Dest MAC: BB:BB:BB:BB:BB:BB
Switch looks up BB:BB:BB:BB:BB:BB → Port 2
Forwards frame only to Port 2
ARP (Address Resolution Protocol)
Maps IP addresses to MAC addresses:
Host A wants to send to 192.168.1.5
1. Check ARP cache
2. If not found, broadcast ARP request:
"Who has 192.168.1.5? Tell 192.168.1.10"
3. Host with 192.168.1.5 replies:
"192.168.1.5 is at AA:BB:CC:DD:EE:FF"
4. Cache the mapping
5. Send frame to AA:BB:CC:DD:EE:FF
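The cache-then-broadcast flow above can be sketched as a toy lookup (the LAN here is simulated as a dict; a real implementation sends actual broadcast frames):

```python
# Simulated ARP: the "LAN" is a dict mapping IP -> MAC.
arp_cache = {}

def resolve(ip, lan):
    if ip in arp_cache:          # 1. check the ARP cache
        return arp_cache[ip]
    mac = lan.get(ip)            # 2-3. broadcast request, unicast reply (simulated)
    if mac is not None:
        arp_cache[ip] = mac      # 4. cache the mapping for next time
    return mac

lan = {"192.168.1.5": "AA:BB:CC:DD:EE:FF"}
```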
Layer 3: Network Layer
Purpose
Routes packets across networks from source to destination.
Responsibilities
- Logical addressing (IP addresses)
- Routing
- Packet forwarding
- Fragmentation and reassembly
- Error handling (ICMP)
Components
- Routers: Forward packets between networks
- Layer 3 Switches: Routing at hardware speed
Protocols
- IP (IPv4, IPv6): Internet Protocol
- ICMP: Error reporting and diagnostics
- OSPF, BGP, RIP: Routing protocols
Example: Routing Decision
Router receives packet for 10.1.2.5
Routing Table:
10.1.0.0/16 via 192.168.1.1
10.1.2.0/24 via 192.168.1.2
0.0.0.0/0 via 192.168.1.254 (default)
Longest prefix match: 10.1.2.0/24
Forward to 192.168.1.2
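Longest-prefix match over the routing table above can be reproduced with ipaddress (a simplified sketch; real routers use tries or TCAM, not a linear scan):

```python
import ipaddress

routes = [
    ("10.1.0.0/16", "192.168.1.1"),
    ("10.1.2.0/24", "192.168.1.2"),
    ("0.0.0.0/0",   "192.168.1.254"),   # default route
]

def next_hop(dest: str) -> str:
    addr = ipaddress.ip_address(dest)
    matches = [(ipaddress.ip_network(prefix), hop)
               for prefix, hop in routes
               if addr in ipaddress.ip_network(prefix)]
    # Most specific (longest) matching prefix wins.
    return max(matches, key=lambda m: m[0].prefixlen)[1]
```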
Packet Journey Example
PC1 (192.168.1.10) → Server (10.0.0.5)
Layer 3 decisions at each hop:
1. PC1: Not local subnet → Send to gateway (192.168.1.1)
2. Router1: Check route → Forward to Router2 (10.0.0.1)
3. Router2: Destination is local → Send to 10.0.0.5
Layer 4: Transport Layer
Purpose
Provides end-to-end communication and reliability.
Responsibilities
- Segmentation and reassembly
- Port addressing
- Connection management
- Flow control
- Error recovery
- Multiplexing
Protocols
- TCP: Reliable, connection-oriented
- UDP: Unreliable, connectionless
Port Numbers
Source Port: Identifies sending application
Dest Port: Identifies receiving application
Well-known ports (0-1023):
80 - HTTP
443 - HTTPS
22 - SSH
53 - DNS
Registered ports (1024-49151):
3306 - MySQL
5432 - PostgreSQL
Dynamic ports (49152-65535):
Ephemeral ports for client connections
Example: TCP Connection
Client (192.168.1.10:5000) → Server (10.0.0.5:80)
Layer 4 provides:
1. Connection establishment (3-way handshake)
2. Reliable delivery (ACKs, retransmission)
3. Ordering (sequence numbers)
4. Flow control (window size)
5. Connection termination (4-way close)
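All five steps are handled by the OS kernel; a loopback round trip with the stdlib socket module exercises them implicitly (connect() performs the 3-way handshake, close() the termination):

```python
import socket
import threading

server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.bind(("127.0.0.1", 0))    # OS picks a free ephemeral port
server.listen(1)
addr = server.getsockname()

received = []
def serve():
    conn, _ = server.accept()    # completes the server side of the handshake
    received.append(conn.recv(1024))
    conn.sendall(b"world")
    conn.close()

t = threading.Thread(target=serve)
t.start()

client = socket.create_connection(addr)  # SYN, SYN-ACK, ACK
client.sendall(b"hello")                 # reliable, ordered byte stream
reply = client.recv(1024)
client.close()
t.join()
server.close()
```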
Multiplexing Example
Web browser opens multiple connections:
Tab 1: 192.168.1.10:5000 → google.com:443
Tab 2: 192.168.1.10:5001 → github.com:443
Tab 3: 192.168.1.10:5002 → stackoverflow.com:443
Transport layer demultiplexes based on port
Layer 5: Session Layer
Purpose
Manages sessions (connections) between applications.
Responsibilities
- Session establishment, maintenance, termination
- Dialog control (half-duplex, full-duplex)
- Synchronization
- Token management
Functions
- Authentication: Verify user credentials
- Authorization: Check permissions
- Session restoration: Resume interrupted sessions
Examples
RPC (Remote Procedure Call):
Client Server
| |
| Session established |
|<------------------------------>|
| Call remote procedure |
|------------------------------->|
| Maintain session state |
|<------------------------------>|
| Session terminated |
NetBIOS:
- Session management for file/printer sharing
- Name registration and resolution
Synchronization Points
File Transfer with checkpoints:
0KB -------- 100KB -------- 200KB -------- 300KB
^ ^ ^ ^
Sync 1 Sync 2 Sync 3 Complete
If failure at 250KB:
Resume from Sync 2 (200KB)
Layer 6: Presentation Layer
Purpose
Translates data between application and network formats.
Responsibilities
- Data format translation
- Encryption/decryption
- Compression/decompression
- Character encoding
Functions
1. Data Translation:
ASCII ↔ EBCDIC
Big-endian ↔ Little-endian
JSON ↔ XML ↔ Binary
2. Encryption:
Plaintext: "Hello World"
↓
SSL/TLS Encryption
↓
Ciphertext: "3k#9$mL..."
3. Compression:
Original: 1000 bytes
↓
GZIP Compression
↓
Compressed: 300 bytes
Examples
SSL/TLS:
Application sends: "GET / HTTP/1.1"
↓
Presentation Layer: Encrypts with TLS
↓
Transport Layer: Sends encrypted data
Image Formats:
- JPEG, PNG, GIF (compressed formats)
- Format conversion for display
Character Encoding:
String "Hello" in different encodings:
ASCII: 48 65 6C 6C 6F
UTF-8: 48 65 6C 6C 6F
UTF-16: 00 48 00 65 00 6C 00 6C 00 6F
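The encodings above are easy to verify in Python (`utf-16-be` is used here to match the big-endian, BOM-less hex shown):

```python
s = "Hello"

ascii_bytes = s.encode("ascii")
utf8_bytes = s.encode("utf-8")        # identical to ASCII for these characters
utf16_bytes = s.encode("utf-16-be")   # 2 bytes per character, big-endian, no BOM
```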
Layer 7: Application Layer
Purpose
Provides network services directly to user applications.
Responsibilities
- Application-level protocols
- User authentication
- Data representation
- Resource sharing
Common Protocols
| Protocol | Port | Purpose |
|---|---|---|
| HTTP/HTTPS | 80/443 | Web browsing |
| FTP | 20/21 | File transfer |
| SMTP | 25 | Email sending |
| POP3 | 110 | Email retrieval |
| IMAP | 143 | Email access |
| DNS | 53 | Name resolution |
| DHCP | 67/68 | IP configuration |
| SSH | 22 | Secure shell |
| Telnet | 23 | Remote terminal |
| SNMP | 161 | Network management |
Example: HTTP Request
User clicks link in browser
Application Layer (HTTP):
GET /index.html HTTP/1.1
Host: example.com
Presentation Layer:
Encrypt with TLS (HTTPS)
Session Layer:
Maintain HTTPS session
Transport Layer:
TCP connection to port 443
Network Layer:
Route to example.com's IP
Data Link Layer:
Frame with MAC address
Physical Layer:
Transmit bits on wire
Data Encapsulation
Encapsulation Process (Sending)
Layer 7: User Data
↓
Layer 4: [TCP Header][Data] → Segment
↓
Layer 3: [IP Header][TCP Header][Data] → Packet
↓
Layer 2: [Eth Header][IP Header][TCP][Data][Eth Trailer] → Frame
↓
Layer 1: 010101110101... → Bits
Decapsulation Process (Receiving)
Layer 1: Receive bits
↓
Layer 2: Remove Ethernet header/trailer → Frame
↓
Layer 3: Remove IP header → Packet
↓
Layer 4: Remove TCP header → Segment
↓
Layer 7: Deliver data to application
PDU (Protocol Data Unit) Names
Layer 7-5: Data
Layer 4: Segment (TCP) / Datagram (UDP)
Layer 3: Packet
Layer 2: Frame
Layer 1: Bits
Complete Communication Example
Sending Email via SMTP
Layer 7 (Application):
- SMTP client: "MAIL FROM: alice@example.com"
- Creates email message
Layer 6 (Presentation):
- Encode as ASCII
- Compress if needed
- Encrypt with TLS
Layer 5 (Session):
- Establish SMTP session
- Authenticate with mail server
Layer 4 (Transport):
- TCP connection to port 25
- Segment data
- Add source/dest ports
Layer 3 (Network):
- Add IP header
- Source: 192.168.1.10
- Dest: 10.0.0.5 (mail server)
- Route to destination
Layer 2 (Data Link):
- Add MAC addresses
- Create Ethernet frame
- Error checking (CRC)
Layer 1 (Physical):
- Convert to electrical signals
- Transmit on cable
Troubleshooting by Layer
Layer 1 (Physical) Issues
Symptoms: No connectivity, link down
Check:
- Cable plugged in?
- Cable damaged?
- Port lights on?
- Power on device?
Tools: Visual inspection, cable tester
Layer 2 (Data Link) Issues
Symptoms: Can't reach other devices on LAN
Check:
- MAC address conflicts?
- Switch port errors?
- VLAN configuration?
- ARP table correct?
Tools: arp -a, show mac address-table
Layer 3 (Network) Issues
Symptoms: Can't reach remote networks
Check:
- IP address correct?
- Subnet mask correct?
- Gateway configured?
- Routing table?
Tools: ping, traceroute, ip route
Layer 4 (Transport) Issues
Symptoms: Can't connect to specific service
Check:
- Port open?
- Firewall blocking?
- Service running?
- TCP handshake succeeds?
Tools: telnet, nc (netcat), netstat
Layer 7 (Application) Issues
Symptoms: Service accessible but not working
Check:
- Application configuration?
- Authentication failing?
- Correct protocol version?
- Application logs?
Tools: curl, application-specific tools
OSI vs Real Protocols
Where Real Protocols Fit
OSI Layer Protocol Examples
---------------------------------------
7 - Application HTTP, FTP, SMTP, DNS
6 - Presentation SSL/TLS, JPEG, MPEG
5 - Session NetBIOS, RPC
4 - Transport TCP, UDP
3 - Network IP, ICMP, OSPF, BGP
2 - Data Link Ethernet, WiFi, PPP
1 - Physical 10BASE-T, 100BASE-TX
TCP/IP Model Mapping
OSI Model TCP/IP Model
-----------------------------------------
7 - Application
6 - Presentation → Application
5 - Session
4 - Transport → Transport
3 - Network → Internet
2 - Data Link
1 - Physical → Network Access
Benefits of Layered Approach
1. Modularity
Change one layer without affecting others
Example: Switch from WiFi to Ethernet
(Only Layer 1/2 change, others unaffected)
2. Standardization
Multiple vendors can interoperate
Example: Any HTTP client can talk to any HTTP server
3. Troubleshooting
Systematic approach from bottom up:
1. Physical: Cable OK?
2. Data Link: Connected to switch?
3. Network: Can ping gateway?
4. Transport: Port open?
5. Application: Service running?
4. Development
Developers focus on their layer
Example: Web developer uses HTTP (Layer 7)
Doesn't need to know about TCP internals
ELI10
The OSI Model is like sending a letter through the mail:
Layer 7 (Application): You write a letter
- What you want to say
Layer 6 (Presentation): You format it nicely
- Maybe encrypt it (secret code)
- Compress it (make it smaller)
Layer 5 (Session): You start a conversation
- "Dear John" and "Sincerely, Alice"
Layer 4 (Transport): You put it in envelopes
- Split into pages if too long
- Number the pages so they can be reassembled
Layer 3 (Network): You write the address
- Where it's going
- Where it's from
Layer 2 (Data Link): Post office processes it
- Local post office routing
- Check if envelope is damaged
Layer 1 (Physical): The mail truck
- Physical delivery
- Roads, trucks, planes
Each layer does its job without worrying about the others!
TCP/IP Model
Overview
The TCP/IP Model (also called Internet Protocol Suite) is a practical, 4-layer networking model that describes how data is transmitted over the internet. Unlike the OSI Model, which is theoretical, TCP/IP is the actual model used in modern networks.
TCP/IP vs OSI Model
OSI Model (7 Layers) TCP/IP Model (4 Layers)
---------------------------------------------------
7. Application
6. Presentation → 4. Application
5. Session
4. Transport → 3. Transport
3. Network → 2. Internet
2. Data Link
1. Physical → 1. Network Access
The 4 Layers
Layer 1: Network Access (Link Layer)
Purpose: Physical transmission of data on a network
Combines:
- OSI Physical Layer (Layer 1)
- OSI Data Link Layer (Layer 2)
Responsibilities:
- Physical addressing (MAC)
- Media access control
- Frame formatting
- Error detection
- Physical transmission
Protocols/Technologies:
- Ethernet (IEEE 802.3)
- WiFi (IEEE 802.11)
- PPP (Point-to-Point Protocol)
- ARP (Address Resolution Protocol)
- RARP (Reverse ARP)
Example:
Data from Internet Layer
↓
Add Ethernet Header:
[Dest MAC: AA:BB:CC:DD:EE:FF]
[Src MAC: 11:22:33:44:55:66]
[Type: 0x0800 (IPv4)]
[Data]
[CRC Checksum]
↓
Convert to bits and transmit
Layer 2: Internet Layer
Purpose: Routes packets across networks
Equivalent to: OSI Network Layer (Layer 3)
Responsibilities:
- Logical addressing (IP)
- Routing between networks
- Packet forwarding
- Fragmentation and reassembly
- Error reporting
Key Protocols:
| Protocol | Purpose | RFC |
|---|---|---|
| IP | Internet Protocol (IPv4, IPv6) | RFC 791, 8200 |
| ICMP | Error reporting, diagnostics | RFC 792 |
| IGMP | Multicast group management | RFC 1112 |
| IPsec | Security (encryption, authentication) | RFC 4301 |
Example: Packet Routing
Source: 192.168.1.10 → Destination: 10.0.0.5
IP Layer adds header:
[Version: 4]
[TTL: 64]
[Protocol: 6 (TCP)]
[Source IP: 192.168.1.10]
[Dest IP: 10.0.0.5]
[Data]
Router at each hop:
1. Decrements TTL
2. Checks routing table
3. Forwards to next hop
4. Recalculates checksum
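The checksum a router recalculates is the Internet checksum from RFC 1071: sum the header as 16-bit words with end-around carry, then take the one's complement. A minimal sketch:

```python
def internet_checksum(data: bytes) -> int:
    # RFC 1071: 16-bit one's-complement sum with end-around carry.
    if len(data) % 2:
        data += b"\x00"               # pad odd-length input
    total = sum(int.from_bytes(data[i:i + 2], "big")
                for i in range(0, len(data), 2))
    while total >> 16:                # fold carries back in
        total = (total & 0xFFFF) + (total >> 16)
    return ~total & 0xFFFF
```

A header that includes its own correct checksum sums to zero, which is exactly how receivers verify it.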
Layer 3: Transport Layer
Purpose: End-to-end communication between applications
Equivalent to: OSI Transport Layer (Layer 4)
Responsibilities:
- Port-based multiplexing
- Connection management
- Reliability (for TCP)
- Flow control
- Error recovery
Key Protocols:
TCP (Transmission Control Protocol)
Characteristics:
- Connection-oriented
- Reliable delivery
- Ordered delivery
- Flow control
- Congestion control
TCP Segment:
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Port | Destination Port |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Sequence Number |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Acknowledgment Number |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Offset| Res | Flags | Window |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Checksum | Urgent Pointer |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Data |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
TCP Connection (3-Way Handshake):
Client Server
| |
| SYN (seq=100) |
|------------------------------->|
| |
| SYN-ACK (seq=200, ack=101) |
|<-------------------------------|
| |
| ACK (seq=101, ack=201) |
|------------------------------->|
| |
| Connection Established |
UDP (User Datagram Protocol)
Characteristics:
- Connectionless
- Unreliable
- No ordering guarantee
- Low overhead
- Fast
UDP Datagram:
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Port | Destination Port |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Length | Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Data |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
UDP Communication:
Client Server
| |
| UDP Datagram |
|------------------------------->|
| |
| No acknowledgment |
Fire and forget!
Layer 4: Application Layer
Purpose: Provides network services to applications
Combines:
- OSI Application Layer (Layer 7)
- OSI Presentation Layer (Layer 6)
- OSI Session Layer (Layer 5)
Responsibilities:
- Application-specific protocols
- Data formatting
- Session management
- User authentication
Common Protocols:
| Protocol | Port | Transport | Purpose |
|---|---|---|---|
| HTTP | 80 | TCP | Web pages |
| HTTPS | 443 | TCP | Secure web |
| FTP | 20/21 | TCP | File transfer |
| SFTP | 22 | TCP | Secure file transfer |
| SSH | 22 | TCP | Secure shell |
| Telnet | 23 | TCP | Remote terminal |
| SMTP | 25 | TCP | Send email |
| DNS | 53 | UDP/TCP | Name resolution |
| DHCP | 67/68 | UDP | IP configuration |
| TFTP | 69 | UDP | Simple file transfer |
| HTTP/3 | 443 | UDP (QUIC) | Modern web |
| NTP | 123 | UDP | Time sync |
| SNMP | 161/162 | UDP | Network management |
| POP3 | 110 | TCP | Email retrieval |
| IMAP | 143 | TCP | Email access |
| RDP | 3389 | TCP | Remote desktop |
Example: HTTP Request
Application Layer creates:
GET /index.html HTTP/1.1
Host: www.example.com
User-Agent: Mozilla/5.0
Accept: text/html
↓
Transport Layer (TCP):
- Add TCP header
- Source port: 54321
- Dest port: 80
- Establish connection
↓
Internet Layer (IP):
- Add IP header
- Resolve www.example.com to IP
- Source: 192.168.1.10
- Dest: 93.184.216.34
↓
Network Access Layer:
- ARP for next hop MAC
- Add Ethernet frame
- Transmit bits
Data Encapsulation in TCP/IP
Sending Data
Step 1: Application creates data
"GET /index.html HTTP/1.1\r\n..."
Step 2: Transport Layer adds header
[TCP Header][HTTP Request] → TCP Segment
Step 3: Internet Layer adds header
[IP Header][TCP Header][HTTP Request] → IP Packet
Step 4: Network Access adds header/trailer
[Eth Header][IP][TCP][HTTP][Eth Trailer] → Ethernet Frame
Step 5: Convert to bits
01001000110101... → Bits on wire
Receiving Data
Step 1: Receive bits, extract frame
[Eth Header][IP][TCP][HTTP][Eth Trailer]
Step 2: Check Ethernet checksum, remove header
[IP Header][TCP Header][HTTP Request]
Step 3: Check IP checksum, route to TCP
[TCP Header][HTTP Request]
Step 4: Process TCP segment, reassemble
"GET /index.html HTTP/1.1\r\n..."
Step 5: Deliver to HTTP server application
Complete Communication Example
Browsing a Website
User types: http://www.example.com
=== Application Layer ===
1. Browser resolves domain name
DNS Query: "What's the IP of www.example.com?"
DNS Response: "93.184.216.34"
2. Browser creates HTTP request
GET / HTTP/1.1
Host: www.example.com
=== Transport Layer ===
3. TCP connection to port 80
- 3-way handshake
- Establish connection
- Segment data if needed
=== Internet Layer ===
4. Create IP packet
- Source: 192.168.1.10
- Dest: 93.184.216.34
- Protocol: TCP (6)
- Add to routing queue
5. Routing
- Check routing table
- Forward to default gateway
- Each router forwards packet
=== Network Access Layer ===
6. Resolve next hop MAC (ARP)
- "Who has 192.168.1.1?"
- "192.168.1.1 is at AA:BB:CC:DD:EE:FF"
7. Create Ethernet frame
- Dest MAC: Gateway's MAC
- Src MAC: PC's MAC
- Add checksum
8. Transmit on physical medium
- Convert to electrical signals
- Send on Ethernet cable
=== Server Processes Request ===
9. Server receives, decapsulates
10. HTTP server processes request
11. Sends response back
=== Browser Receives Response ===
12. Decapsulate all layers
13. Browser renders HTML
Protocol Interactions
DNS Resolution
Application: DNS client
Transport: UDP port 53
Internet: IP packet to DNS server
Network Access: Ethernet to gateway
Query: www.example.com → 93.184.216.34
Email Sending (SMTP)
Application: SMTP client (port 25)
Transport: TCP connection
Internet: Route to mail server IP
Network Access: Frame to gateway
MAIL FROM: alice@example.com
RCPT TO: bob@example.com
DATA
Subject: Hello
...
File Transfer (FTP)
Application: FTP client
Transport:
- Control: TCP port 21
- Data: TCP port 20
Internet: IP to FTP server
Network Access: Ethernet frames
Commands on port 21:
USER alice
PASS secret123
RETR file.txt
Data transfer on port 20
Port Numbers
Well-Known Ports (0-1023)
Binding to these ports requires root/admin privileges on most systems:
20/21 FTP
22 SSH
23 Telnet
25 SMTP
53 DNS
67/68 DHCP
80 HTTP
110 POP3
143 IMAP
443 HTTPS
Registered Ports (1024-49151)
For specific services:
3306 MySQL
5432 PostgreSQL
6379 Redis
8080 HTTP alternate
8443 HTTPS alternate
27017 MongoDB
Dynamic/Private Ports (49152-65535)
Used by clients for outgoing connections:
Client opens connection:
Source: 192.168.1.10:54321 (dynamic)
Dest: 93.184.216.34:80 (well-known)
TCP/IP Configuration
Manual Configuration
# Set IP address
sudo ip addr add 192.168.1.100/24 dev eth0
# Set default gateway
sudo ip route add default via 192.168.1.1
# Set DNS server
echo "nameserver 8.8.8.8" | sudo tee -a /etc/resolv.conf
DHCP (Dynamic Host Configuration Protocol)
Client DHCP Server
| |
| DHCP Discover (broadcast) |
|------------------------------->|
| |
| DHCP Offer |
|<-------------------------------|
| IP: 192.168.1.100 |
| Netmask: 255.255.255.0 |
| Gateway: 192.168.1.1 |
| DNS: 8.8.8.8 |
| |
| DHCP Request |
|------------------------------->|
| |
| DHCP ACK |
|<-------------------------------|
| |
Client now configured with:
IP Address: 192.168.1.100
Subnet Mask: 255.255.255.0
Default Gateway: 192.168.1.1
DNS Server: 8.8.8.8
Lease Time: 24 hours
TCP/IP Troubleshooting
Layer 1: Network Access
# Check physical connection
ip link show
ethtool eth0
# Check link status
cat /sys/class/net/eth0/carrier
Symptoms: No link light, cable unplugged
Layer 2: Network Access (Data Link)
# Check ARP table
arp -a
ip neigh show
# Check switch port
show mac address-table
Symptoms: Can't reach local devices
Layer 3: Internet
# Check IP configuration
ip addr show
ifconfig
# Test gateway reachability
ping 192.168.1.1
# Check routing
ip route show
traceroute 8.8.8.8
Symptoms: No internet, can't reach remote hosts
Layer 4: Transport
# Check listening ports
netstat -tuln
ss -tuln
# Test port connectivity
telnet example.com 80
nc -zv example.com 80
# Check firewall
iptables -L
ufw status
Symptoms: Connection refused, timeout
Layer 5: Application
# Test HTTP
curl -v http://example.com
# Test DNS
dig example.com
nslookup example.com
# Test SMTP
telnet mail.example.com 25
Symptoms: Service not responding correctly
TCP/IP Security
Common Vulnerabilities
1. IP Spoofing
Attacker sends packets with fake source IP
Victim: 10.0.0.5
Attacker pretends to be: 10.0.0.5
2. TCP SYN Flood
Attacker sends many SYN packets
Server waits for ACK (never comes)
Server resources exhausted
3. Man-in-the-Middle
Attacker intercepts traffic between client and server
Can read or modify data
Security Protocols
IPsec (Internet Protocol Security)
Provides:
- Authentication Header (AH)
- Encapsulating Security Payload (ESP)
- Encryption and authentication
Used for VPNs
TLS/SSL (Transport Layer Security)
Encrypts application data
Provides:
- Confidentiality (encryption)
- Integrity (tampering detection)
- Authentication (certificates)
Used for HTTPS, SMTPS, etc.
TCP/IP Performance Tuning
TCP Window Scaling
Default window: 65,535 bytes
With scaling: Up to ~1 GiB (65,535 × 2^14)
Improves throughput on high-latency links
TCP Congestion Control Algorithms
- Tahoe: Original algorithm
- Reno: Fast recovery
- CUBIC: Default in Linux (good for high-speed)
- BBR: Google's algorithm (models bottleneck bandwidth and RTT)
Monitoring TCP Performance
# TCP statistics
netstat -s
ss -s
# Per-connection statistics
ss -ti
# Packet captures
tcpdump -i any -w capture.pcap
ELI10
TCP/IP is how computers talk to each other on the internet:
Layer 1: Network Access (The Road)
- Physical cables and WiFi
- Like the road system for mail delivery
Layer 2: Internet (The Address)
- IP addresses (like street addresses)
- Routers (like post offices) send packets the right way
Layer 3: Transport (The Envelope)
- TCP: Certified mail (guaranteed delivery, in order)
- UDP: Postcard (fast, but might get lost)
Layer 4: Application (The Message)
- The actual letter you're sending
- HTTP for websites, SMTP for email, etc.
Example: Loading a website
- You type www.google.com
- DNS finds Google's address (142.250.185.78)
- TCP opens a connection (handshake)
- HTTP sends "Give me the homepage"
- Routers deliver packets to Google
- Google sends back the HTML
- Your browser shows the page
Each layer does its job without worrying about the others!
Further Resources
- RFC 1122 - Requirements for Internet Hosts
- RFC 1123 - Application and Support
- TCP/IP Illustrated by W. Richard Stevens
- TCP/IP Guide
IP (Internet Protocol)
Overview
IP (Internet Protocol) is the network layer protocol responsible for addressing and routing packets across networks. It provides the addressing scheme that allows devices to find each other on the internet.
IP Versions
| Feature | IPv4 | IPv6 |
|---|---|---|
| Address Size | 32 bits | 128 bits |
| Address Format | Decimal (192.168.1.1) | Hexadecimal (2001:db8::1) |
| Total Addresses | ~4.3 billion | 340 undecillion |
| Header Size | 20-60 bytes | 40 bytes (fixed) |
| Checksum | Yes | No (delegated to link layer) |
| Fragmentation | Routers and sender | Sender only |
| Broadcast | Yes | No (uses multicast) |
| Configuration | Manual or DHCP | SLAAC or DHCPv6 |
| IPSec | Optional | Mandatory |
IPv4 Packet Format
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|Version| IHL |Type of Service| Total Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Identification |Flags| Fragment Offset |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Time to Live | Protocol | Header Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Address |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Destination Address |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Options (if IHL > 5) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Data |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
IPv4 Header Fields
- Version (4 bits): IP version (4 for IPv4)
- IHL (4 bits): Internet Header Length (5-15, in 32-bit words)
- Type of Service (8 bits): QoS, priority
- Total Length (16 bits): Entire packet size (max 65,535 bytes)
- Identification (16 bits): Fragment identification
- Flags (3 bits):
- Bit 0: Reserved (must be 0)
- Bit 1: Don't Fragment (DF)
- Bit 2: More Fragments (MF)
- Fragment Offset (13 bits): Position of fragment
- Time to Live (TTL) (8 bits): Max hops (decremented at each router)
- Protocol (8 bits): Upper layer protocol (6=TCP, 17=UDP, 1=ICMP)
- Header Checksum (16 bits): Error detection for header
- Source Address (32 bits): Sender IP address
- Destination Address (32 bits): Receiver IP address
- Options (variable): Rarely used (security, timestamp, etc.)
IPv4 Address Classes
Traditional Class System (Obsolete, replaced by CIDR)
Class A: 0.0.0.0 to 127.255.255.255 /8 (16,777,214 hosts)
Network: 8 bits, Host: 24 bits
Class B: 128.0.0.0 to 191.255.255.255 /16 (65,534 hosts)
Network: 16 bits, Host: 16 bits
Class C: 192.0.0.0 to 223.255.255.255 /24 (254 hosts)
Network: 24 bits, Host: 8 bits
Class D: 224.0.0.0 to 239.255.255.255 (Multicast)
Class E: 240.0.0.0 to 255.255.255.255 (Reserved)
Private IP Address Ranges
10.0.0.0 - 10.255.255.255 (10/8 prefix)
172.16.0.0 - 172.31.255.255 (172.16/12 prefix)
192.168.0.0 - 192.168.255.255 (192.168/16 prefix)
Used in LANs, not routed on internet (NAT required)
Special IPv4 Addresses
0.0.0.0/8 - Current network (only valid as source)
127.0.0.0/8 - Loopback (127.0.0.1 = localhost)
169.254.0.0/16 - Link-local (APIPA, auto-config failed)
192.0.2.0/24 - Documentation/examples (TEST-NET-1)
198.18.0.0/15 - Benchmark testing
224.0.0.0/4 - Multicast
255.255.255.255 - Limited broadcast
CIDR (Classless Inter-Domain Routing)
CIDR Notation
192.168.1.0/24
^^
Number of network bits
/24 = 255.255.255.0 netmask
24 bits for network, 8 bits for hosts
2^8 - 2 = 254 usable host addresses
Common Subnet Masks
| CIDR | Netmask | Hosts | Use Case |
|---|---|---|---|
| /8 | 255.0.0.0 | 16,777,214 | Huge networks |
| /16 | 255.255.0.0 | 65,534 | Large networks |
| /24 | 255.255.255.0 | 254 | Small networks |
| /25 | 255.255.255.128 | 126 | Subnet split |
| /26 | 255.255.255.192 | 62 | Small subnet |
| /27 | 255.255.255.224 | 30 | Very small |
| /30 | 255.255.255.252 | 2 | Point-to-point |
| /32 | 255.255.255.255 | 1 | Single host |
Subnet Calculation Example
Network: 192.168.1.0/24
Network Address: 192.168.1.0
First Usable: 192.168.1.1
Last Usable: 192.168.1.254
Broadcast Address: 192.168.1.255
Total Addresses: 256
Usable Hosts: 254
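The figures above can be reproduced with Python's standard `ipaddress` module, which computes network, broadcast, and usable-host ranges directly from the CIDR string:

```python
import ipaddress

# The 192.168.1.0/24 calculation above, done with the stdlib.
net = ipaddress.ip_network("192.168.1.0/24")

network_addr = net.network_address   # 192.168.1.0
broadcast = net.broadcast_address    # 192.168.1.255
total = net.num_addresses            # 256
usable = net.num_addresses - 2       # 254 (network + broadcast excluded)
hosts = list(net.hosts())            # all usable host addresses
first_usable, last_usable = hosts[0], hosts[-1]
```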
Subnetting Example
Original: 192.168.1.0/24 (254 hosts)
Split into 4 subnets (/26 each):
Subnet 1: 192.168.1.0/26 (192.168.1.1 - 192.168.1.62)
Subnet 2: 192.168.1.64/26 (192.168.1.65 - 192.168.1.126)
Subnet 3: 192.168.1.128/26 (192.168.1.129 - 192.168.1.190)
Subnet 4: 192.168.1.192/26 (192.168.1.193 - 192.168.1.254)
Each subnet: 62 usable hosts
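The four-way split above is exactly what `ipaddress.ip_network(...).subnets()` produces when asked for /26 children:

```python
import ipaddress

# Splitting 192.168.1.0/24 into the four /26 subnets listed above.
net = ipaddress.ip_network("192.168.1.0/24")
subnets = list(net.subnets(new_prefix=26))

for sn in subnets:
    hosts = list(sn.hosts())
    print(sn, hosts[0], "-", hosts[-1])  # e.g. 192.168.1.0/26 192.168.1.1 - 192.168.1.62
```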
IPv6 Packet Format
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|Version| Traffic Class | Flow Label |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Payload Length | Next Header | Hop Limit |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
+ Source Address +
| (128 bits) |
+ +
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
+ Destination Address +
| (128 bits) |
+ +
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
IPv6 Header Fields (40 bytes fixed)
- Version (4 bits): IP version (6)
- Traffic Class (8 bits): QoS, similar to ToS in IPv4
- Flow Label (20 bits): QoS flow identification
- Payload Length (16 bits): Data length (excluding header)
- Next Header (8 bits): Protocol type (like IPv4 Protocol field)
- Hop Limit (8 bits): Like IPv4 TTL
- Source Address (128 bits)
- Destination Address (128 bits)
IPv6 Address Format
Full Representation
2001:0db8:0000:0042:0000:8a2e:0370:7334
Compressed Representation
# Remove leading zeros
2001:db8:0:42:0:8a2e:370:7334
# Replace consecutive zeros with ::
2001:db8:0:42::8a2e:370:7334
# Loopback
::1 (equivalent to 0:0:0:0:0:0:0:1)
# Unspecified
:: (equivalent to 0:0:0:0:0:0:0:0)
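The same expansion and compression rules are available via `ipaddress` (`.exploded` and `.compressed`). Note that Python follows RFC 5952, which uses `::` only for runs of two or more zero groups, so a single zero group stays uncompressed:

```python
import ipaddress

# Expanding and compressing IPv6 addresses with the stdlib.
addr = ipaddress.ip_address("2001:db8::1")
full = addr.exploded     # '2001:0db8:0000:0000:0000:0000:0000:0001'
short = addr.compressed  # '2001:db8::1'

loopback = ipaddress.ip_address("::1")  # same value as 0:0:0:0:0:0:0:1
```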
IPv6 Address Types
| Type | Prefix | Example | Purpose |
|---|---|---|---|
| Global Unicast | 2000::/3 | 2001:db8::1 | Internet routing |
| Link-Local | fe80::/10 | fe80::1 | Local network only |
| Unique Local | fc00::/7 | fd00::1 | Private (like RFC 1918) |
| Multicast | ff00::/8 | ff02::1 | One-to-many |
| Loopback | ::1/128 | ::1 | Localhost |
| Unspecified | ::/128 | :: | No address |
Common IPv6 Multicast Addresses
ff02::1 All nodes on link
ff02::2 All routers on link
ff02::1:2 All DHCP relay agents and servers
IP Fragmentation
Why Fragmentation?
MTU (Maximum Transmission Unit) varies by network:
- Ethernet: 1500 bytes
- WiFi: 2304 bytes
- PPPoE: 1492 bytes
Larger packets must be fragmented to fit MTU
IPv4 Fragmentation Process
Original packet: 3000 bytes (1500 MTU)
Fragment 1:
Identification: 12345
Flags: More Fragments (MF) = 1
Offset: 0
Data: 1480 bytes
Fragment 2:
Identification: 12345
Flags: MF = 1
Offset: 185 (1480/8)
Data: 1480 bytes
Fragment 3:
Identification: 12345
Flags: MF = 0 (last fragment)
Offset: 370 (2960/8)
Data: 40 bytes
Receiver reassembles using Identification and Offset
Don't Fragment (DF) Flag
DF = 1: Don't fragment, drop if too large
Send ICMP "Fragmentation Needed" back
Used for Path MTU Discovery
IP Routing
Routing Decision Process
1. Check if destination is local (same subnet)
→ Send directly via ARP
2. If not local, find matching route:
- Check routing table for most specific match
- Use default gateway if no match
3. Send to next hop router
4. Repeat at each router until destination reached
Example Routing Table
Destination Gateway Netmask Interface
0.0.0.0 192.168.1.1 0.0.0.0 eth0 (Default)
192.168.1.0 0.0.0.0 255.255.255.0 eth0 (Local)
10.0.0.0 192.168.1.254 255.0.0.0 eth0 (Route)
Longest Prefix Match
Routing table:
10.0.0.0/8 → Gateway A
10.1.0.0/16 → Gateway B
10.1.2.0/24 → Gateway C
Packet to 10.1.2.5:
Matches all three routes
Most specific: /24
→ Use Gateway C
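The lookup above can be sketched in a few lines: collect every route whose network contains the destination, then pick the one with the longest prefix. A real router uses a trie for speed; this linear scan is just to show the rule.

```python
import ipaddress

# Longest-prefix match over the routing table above.
routes = {
    "10.0.0.0/8":  "Gateway A",
    "10.1.0.0/16": "Gateway B",
    "10.1.2.0/24": "Gateway C",
}

def lookup(dest, routes):
    dest = ipaddress.ip_address(dest)
    matches = [ipaddress.ip_network(p) for p in routes
               if dest in ipaddress.ip_network(p)]
    if not matches:
        return None
    best = max(matches, key=lambda n: n.prefixlen)  # most specific wins
    return routes[str(best)]
```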
TTL (Time to Live)
Purpose
Prevents routing loops by limiting packet lifetime:
Source sets TTL = 64
Router 1: TTL = 63
Router 2: TTL = 62
Router 3: TTL = 61
...
Router N: TTL = 0 → Drop packet, send ICMP "Time Exceeded"
Common TTL Values
Linux: 64
Windows: 128
Cisco: 255
Can identify OS based on initial TTL
Traceroute Uses TTL
Send packet with TTL=1 → Router 1 responds
Send packet with TTL=2 → Router 2 responds
Send packet with TTL=3 → Router 3 responds
...
Maps the path to destination
IP Commands and Tools
ifconfig / ip (Linux)
# View IP configuration
ifconfig
ip addr show
# Assign IP address
sudo ifconfig eth0 192.168.1.100 netmask 255.255.255.0
sudo ip addr add 192.168.1.100/24 dev eth0
# Enable/disable interface
sudo ifconfig eth0 up
sudo ip link set eth0 up
ipconfig (Windows)
# View IP configuration
ipconfig
ipconfig /all
# Renew DHCP lease
ipconfig /renew
# Release DHCP lease
ipconfig /release
ping
# Test connectivity (ICMP Echo Request/Reply)
ping 192.168.1.1
ping -c 4 192.168.1.1 # Send 4 packets
# Test with specific packet size
ping -s 1400 192.168.1.1
# Set TTL
ping -t 10 192.168.1.1
traceroute / tracert
# Linux
traceroute google.com
# Windows
tracert google.com
# UDP traceroute (Linux)
traceroute -U google.com
# ICMP traceroute
traceroute -I google.com
netstat
# Show routing table
netstat -r
route -n
# Show all connections
netstat -an
# Show listening ports
netstat -ln
ip route
# Show routing table
ip route show
# Add static route
sudo ip route add 10.0.0.0/8 via 192.168.1.254
# Delete route
sudo ip route del 10.0.0.0/8
# Add default gateway
sudo ip route add default via 192.168.1.1
NAT (Network Address Translation)
Why NAT?
Problem: IPv4 address exhaustion
Solution: Multiple private IPs share one public IP
Private Network (192.168.1.0/24)
PC1: 192.168.1.10
PC2: 192.168.1.11 → NAT Router → Public IP: 203.0.113.5
PC3: 192.168.1.12 ↓
Tracks connections
NAT Types
1. Source NAT (SNAT)
Outbound translation:
PC (192.168.1.10:5000) → NAT → Internet (203.0.113.5:6000)
Return traffic:
Internet (203.0.113.5:6000) → NAT → PC (192.168.1.10:5000)
2. Destination NAT (DNAT) / Port Forwarding
Internet → Public IP:80 → NAT → Web Server (192.168.1.20:80)
External: 203.0.113.5:80
Internal: 192.168.1.20:80
3. PAT (Port Address Translation) / NAT Overload
PC1: 192.168.1.10:5000 → 203.0.113.5:6000
PC2: 192.168.1.11:5001 → 203.0.113.5:6001
PC3: 192.168.1.12:5002 → 203.0.113.5:6002
NAT tracks: Internal IP:Port ↔ Public Port
NAT Table Example
Internal IP:Port External IP:Port Destination
192.168.1.10:5000 203.0.113.5:6000 8.8.8.8:53
192.168.1.11:5001 203.0.113.5:6001 1.1.1.1:443
192.168.1.10:5002 203.0.113.5:6002 93.184.216.34:80
ICMP (Internet Control Message Protocol)
Part of IP suite, used for diagnostics and errors:
Common ICMP Message Types
| Type | Code | Message | Use |
|---|---|---|---|
| 0 | 0 | Echo Reply | ping response |
| 3 | 0 | Dest Network Unreachable | Routing error |
| 3 | 1 | Dest Host Unreachable | Host down |
| 3 | 3 | Dest Port Unreachable | Port closed |
| 3 | 4 | Fragmentation Needed | MTU discovery |
| 8 | 0 | Echo Request | ping |
| 11 | 0 | Time Exceeded | TTL = 0 |
| 30 | 0 | Traceroute | Traceroute packet (deprecated) |
Ping (ICMP Echo Request/Reply)
Client Server
| |
| ICMP Echo Request (Type 8) |
|------------------------------->|
| |
| ICMP Echo Reply (Type 0) |
|<-------------------------------|
| |
Measures round-trip time (RTT)
IP Best Practices
1. Subnet Properly
Don't use /24 for everything
- Small office: /26 (62 hosts)
- Department: /24 (254 hosts)
- Campus: /16 (65,534 hosts)
2. Reserve IP Ranges
192.168.1.1 - 192.168.1.10 Static (gateway, servers)
192.168.1.11 - 192.168.1.99 Static (printers, APs)
192.168.1.100 - 192.168.1.254 DHCP pool
3. Document Network
Maintain IP address management (IPAM)
- Which IPs are assigned
- What devices use them
- DHCP ranges
- Static assignments
4. Use Private IPs Internally
Never use public IPs internally
Use 10.x.x.x, 172.16-31.x.x, or 192.168.x.x
ELI10
IP addresses are like street addresses for computers:
IPv4 Address (192.168.1.100):
- Like a home address with 4 numbers
- Each number is 0-255
- Almost ran out of addresses (like running out of street addresses in a city)
IPv6 Address (2001:db8::1):
- New address system with way more numbers
- Like adding ZIP+4 codes, apartment numbers, floor numbers
- So many addresses we'll never run out
Routing:
- Routers are like mail sorting facilities
- They look at the address and send the packet closer to its destination
- Each router knows which direction to send packets
Private IPs:
- Like apartment numbers (Apt 101, 102, 103)
- Work inside the building (local network)
- NAT is like the building's street address (everyone shares it for mail)
Further Resources
- RFC 791 - IPv4 Specification
- RFC 8200 - IPv6 Specification
- RFC 1918 - Private Address Space
- Subnet Calculator
- CIDR to IPv4 Conversion
IPv4 (Internet Protocol version 4)
Overview
IPv4 (Internet Protocol version 4) is the fourth version of the Internet Protocol and the first version to be widely deployed. It is the network layer protocol responsible for addressing and routing packets across networks, providing the addressing scheme that allows devices to find each other on the internet.
Key Characteristics
| Feature | IPv4 |
|---|---|
| Address Size | 32 bits |
| Address Format | Decimal dotted notation (192.168.1.1) |
| Total Addresses | ~4.3 billion (2³²) |
| Header Size | 20-60 bytes (variable) |
| Checksum | Yes (header checksum) |
| Fragmentation | By routers and source |
| Broadcast | Yes |
| Configuration | Manual or DHCP |
| IPSec | Optional |
IPv4 Packet Format
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|Version| IHL |Type of Service| Total Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Identification |Flags| Fragment Offset |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Time to Live | Protocol | Header Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Address |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Destination Address |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Options (if IHL > 5) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Data |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
IPv4 Header Fields
- Version (4 bits): IP version (4 for IPv4)
- IHL (4 bits): Internet Header Length (5-15, in 32-bit words)
- Minimum: 5 (20 bytes)
- Maximum: 15 (60 bytes)
- Type of Service (8 bits): QoS, priority
- Precedence (3 bits): Priority level
- Delay, Throughput, Reliability (3 bits)
- Used for traffic prioritization
- Total Length (16 bits): Entire packet size including header (max 65,535 bytes)
- Identification (16 bits): Fragment identification
- All fragments of the same packet share this value
- Flags (3 bits):
- Bit 0: Reserved (must be 0)
- Bit 1: Don't Fragment (DF) - prevents fragmentation
- Bit 2: More Fragments (MF) - indicates more fragments follow
- Fragment Offset (13 bits): Position of fragment in original packet (in 8-byte units)
- Time to Live (TTL) (8 bits): Max hops (decremented at each router)
- Prevents infinite routing loops
- Typical initial values: 64 (Linux), 128 (Windows), 255 (Cisco)
- Protocol (8 bits): Upper layer protocol
- 1 = ICMP
- 6 = TCP
- 17 = UDP
- Header Checksum (16 bits): Error detection for header only
- Recalculated at each hop (because TTL changes)
- Source Address (32 bits): Sender IPv4 address
- Destination Address (32 bits): Receiver IPv4 address
- Options (variable): Rarely used today
- Security, timestamps, route recording, source routing
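The header checksum is a one's-complement sum of the header's 16-bit words, with the checksum field itself zeroed while computing. A minimal sketch, checked against a commonly cited worked example header:

```python
import struct

def ipv4_checksum(header: bytes) -> int:
    """One's-complement checksum over 16-bit words of an IPv4 header."""
    if len(header) % 2:
        header += b"\x00"
    total = sum(struct.unpack("!%dH" % (len(header) // 2), header))
    while total > 0xFFFF:               # fold carry bits back into the sum
        total = (total & 0xFFFF) + (total >> 16)
    return ~total & 0xFFFF

# 20-byte header with the checksum field (bytes 10-11) zeroed
header = bytes.fromhex("450000730000400040110000c0a80001c0a800c7")
print(hex(ipv4_checksum(header)))  # prints 0xb861
```

A receiver verifies by summing the header with the checksum included: the result folds to 0xFFFF, so the function returns 0 for an intact header.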
IPv4 Address Classes
Traditional Class System (Obsolete, replaced by CIDR)
The classful addressing system divided the IPv4 address space into five classes (A-E), but this system was wasteful and is now obsolete. It's still useful to understand for historical reasons.
Class A: 0.0.0.0 to 127.255.255.255 /8 (16,777,214 hosts)
Network: 8 bits, Host: 24 bits
First bit: 0
Example: 10.0.0.0/8
Class B: 128.0.0.0 to 191.255.255.255 /16 (65,534 hosts)
Network: 16 bits, Host: 16 bits
First two bits: 10
Example: 172.16.0.0/16
Class C: 192.0.0.0 to 223.255.255.255 /24 (254 hosts)
Network: 24 bits, Host: 8 bits
First three bits: 110
Example: 192.168.1.0/24
Class D: 224.0.0.0 to 239.255.255.255 (Multicast)
First four bits: 1110
Used for multicast groups
Class E: 240.0.0.0 to 255.255.255.255 (Reserved)
First four bits: 1111
Reserved for experimental use
Why Classes Were Abandoned
Problem: Wasteful allocation
- Small company needs 300 hosts
- Class C (/24): Only 254 hosts (too small)
- Class B (/16): 65,534 hosts (massive waste)
Solution: CIDR (Classless Inter-Domain Routing)
- Flexible subnet sizes
- Better address utilization
Private IP Address Ranges
Private IP addresses are reserved for use in private networks and are not routed on the public internet. Network Address Translation (NAT) is required to access the internet.
10.0.0.0 - 10.255.255.255 (10.0.0.0/8)
16,777,216 addresses
Typically used in large enterprises
172.16.0.0 - 172.31.255.255 (172.16.0.0/12)
1,048,576 addresses
Medium-sized networks
192.168.0.0 - 192.168.255.255 (192.168.0.0/16)
65,536 addresses
Home and small office networks
Advantages of Private Addresses
- Address Conservation: Reuse addresses across different private networks
- Security: Not directly accessible from the internet
- Flexibility: Can use any addressing scheme internally
- Cost: No need to purchase public IP addresses
Special IPv4 Addresses
0.0.0.0/8 - Current network (only valid as source)
Used during boot before IP is configured
127.0.0.0/8 - Loopback addresses
127.0.0.1 = localhost (most common)
Packets sent to loopback never leave the host
169.254.0.0/16 - Link-local addresses (APIPA)
Auto-assigned when DHCP fails
First and last 256 addresses (169.254.0.x, 169.254.255.x) reserved
192.0.2.0/24 - Documentation/examples (TEST-NET-1)
198.51.100.0/24 - Documentation (TEST-NET-2)
203.0.113.0/24 - Documentation (TEST-NET-3)
Safe to use in documentation, never routed
192.88.99.0/24 - 6to4 Relay Anycast (IPv6 transition)
198.18.0.0/15 - Benchmark testing
Network device testing
224.0.0.0/4 - Multicast (Class D)
224.0.0.0 - 239.255.255.255
255.255.255.255 - Limited broadcast
Sent to all hosts on local network segment
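Several of these special-purpose ranges are baked into the stdlib `ipaddress` module as convenience properties, which is a quick way to classify an address:

```python
import ipaddress

# Checking special ranges from the list above via stdlib properties.
a = ipaddress.ip_address
assert a("10.1.2.3").is_private        # RFC 1918 private range
assert a("127.0.0.1").is_loopback      # loopback
assert a("169.254.1.1").is_link_local  # APIPA / link-local
assert a("224.0.0.1").is_multicast     # multicast (Class D)
assert not a("8.8.8.8").is_private     # ordinary public address
```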
CIDR (Classless Inter-Domain Routing)
CIDR replaced the classful addressing system, providing flexible subnetting and efficient address allocation.
CIDR Notation
192.168.1.0/24
^^
Number of network bits (subnet mask length)
/24 = 255.255.255.0 netmask
24 bits for network, 8 bits for hosts
2^8 = 256 total addresses
2^8 - 2 = 254 usable host addresses
(Network address and broadcast address reserved)
CIDR Notation Breakdown
192.168.1.0/24
Binary:
11000000.10101000.00000001.00000000 (IP address)
11111111.11111111.11111111.00000000 (Subnet mask /24)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
24 network bits (1s) ^^^^^^^^
8 host bits (0s)
Network portion: 192.168.1
Host portion: 0-255
Common Subnet Masks
| CIDR | Netmask | Wildcard | Total | Usable | Use Case |
|---|---|---|---|---|---|
| /8 | 255.0.0.0 | 0.255.255.255 | 16,777,216 | 16,777,214 | Huge networks (Class A) |
| /12 | 255.240.0.0 | 0.15.255.255 | 1,048,576 | 1,048,574 | Large ISPs |
| /16 | 255.255.0.0 | 0.0.255.255 | 65,536 | 65,534 | Large networks (Class B) |
| /20 | 255.255.240.0 | 0.0.15.255 | 4,096 | 4,094 | Medium businesses |
| /24 | 255.255.255.0 | 0.0.0.255 | 256 | 254 | Small networks (Class C) |
| /25 | 255.255.255.128 | 0.0.0.127 | 128 | 126 | Subnet split |
| /26 | 255.255.255.192 | 0.0.0.63 | 64 | 62 | Small subnet |
| /27 | 255.255.255.224 | 0.0.0.31 | 32 | 30 | Very small |
| /28 | 255.255.255.240 | 0.0.0.15 | 16 | 14 | Tiny subnet |
| /29 | 255.255.255.248 | 0.0.0.7 | 8 | 6 | Minimal |
| /30 | 255.255.255.252 | 0.0.0.3 | 4 | 2 | Point-to-point links |
| /31 | 255.255.255.254 | 0.0.0.1 | 2 | 2 | Point-to-point (RFC 3021) |
| /32 | 255.255.255.255 | 0.0.0.0 | 1 | 1 | Single host route |
Subnet Calculation Example
Network: 192.168.1.0/24
Binary calculation:
IP: 11000000.10101000.00000001.00000000
Mask: 11111111.11111111.11111111.00000000
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Network bits
^^^^^^^^ Host bits
Network Address: 192.168.1.0 (all host bits = 0)
First Usable: 192.168.1.1 (first host)
Last Usable: 192.168.1.254 (last host)
Broadcast Address: 192.168.1.255 (all host bits = 1)
Total Addresses: 2^8 = 256
Usable Hosts: 256 - 2 = 254
(exclude network and broadcast)
Subnet Mask Calculation
To calculate subnet mask from CIDR:
/24 in binary:
11111111.11111111.11111111.00000000
^^^^^^^^ ^^^^^^^^ ^^^^^^^^ ^^^^^^^^
255 255 255 0
/26 in binary:
11111111.11111111.11111111.11000000
^^^^^^^^ ^^^^^^^^ ^^^^^^^^ ^^^^^^^^
255 255 255 192
Shortcut (valid for /24 and longer, where all host bits fall in the last octet):
/24 = 256 - 2^(32-24) = 256 - 2^8 = 256 - 256 = 0 (last octet)
/26 = 256 - 2^(32-26) = 256 - 2^6 = 256 - 64 = 192 (last octet)
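The general version of this conversion (working for any prefix length, not just /24 and longer) sets the top `prefix` bits of a 32-bit value and formats it as four octets:

```python
# Deriving the dotted netmask from a CIDR prefix length.
def prefix_to_netmask(prefix: int) -> str:
    mask = (0xFFFFFFFF << (32 - prefix)) & 0xFFFFFFFF if prefix else 0
    return ".".join(str((mask >> shift) & 0xFF) for shift in (24, 16, 8, 0))
```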
Subnetting Example
Original Network: 192.168.1.0/24 (254 usable hosts)
Requirement: Split into 4 equal subnets
Calculation:
- Need 4 subnets = 2^2 subnets
- Borrow 2 bits from host portion
- New mask: /24 + 2 = /26
- Each subnet: 2^6 = 64 addresses, 62 usable
Result:
Subnet 1: 192.168.1.0/26
Network: 192.168.1.0
First Host: 192.168.1.1
Last Host: 192.168.1.62
Broadcast: 192.168.1.63
Subnet 2: 192.168.1.64/26
Network: 192.168.1.64
First Host: 192.168.1.65
Last Host: 192.168.1.126
Broadcast: 192.168.1.127
Subnet 3: 192.168.1.128/26
Network: 192.168.1.128
First Host: 192.168.1.129
Last Host: 192.168.1.190
Broadcast: 192.168.1.191
Subnet 4: 192.168.1.192/26
Network: 192.168.1.192
First Host: 192.168.1.193
Last Host: 192.168.1.254
Broadcast: 192.168.1.255
Variable Length Subnet Masking (VLSM)
VLSM allows different subnet sizes within the same network:
Main Network: 10.0.0.0/8
Allocations:
Department A (needs 500 hosts): 10.0.0.0/23 (510 hosts)
Department B (needs 200 hosts): 10.0.2.0/24 (254 hosts)
Department C (needs 100 hosts): 10.0.3.0/25 (126 hosts)
Point-to-point link: 10.0.3.128/30 (2 hosts)
Server subnet: 10.0.4.0/28 (14 hosts)
Benefits:
- Efficient address utilization
- Minimizes waste
- Flexible network design
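A VLSM plan like the one above has two correctness conditions: each subnet must be big enough for its host count, and no two allocations may overlap. Both are easy to verify with `ipaddress`:

```python
import ipaddress

# Sanity-checking the VLSM plan above: sizes and non-overlap.
plan = {
    "10.0.0.0/23":   500,  # Department A
    "10.0.2.0/24":   200,  # Department B
    "10.0.3.0/25":   100,  # Department C
    "10.0.3.128/30": 2,    # Point-to-point link
    "10.0.4.0/28":   14,   # Server subnet
}

nets = [ipaddress.ip_network(p) for p in plan]
fits = all(n.num_addresses - 2 >= needed
           for n, needed in zip(nets, plan.values()))
no_overlap = all(not a.overlaps(b)
                 for i, a in enumerate(nets) for b in nets[i + 1:])
```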
IP Fragmentation
Why Fragmentation?
Every network has a Maximum Transmission Unit (MTU) that limits packet size:
Common MTU Values:
- Ethernet: 1500 bytes
- Wi-Fi: 2304 bytes
- PPPoE: 1492 bytes
- VPN: 1400 bytes (varies)
- Jumbo Frames: 9000 bytes
When packet > MTU:
- Must be fragmented to fit
- Or dropped if DF flag is set
IPv4 Fragmentation Process
Fragmentation can occur at the source or any router along the path:
Original Packet: 3000 bytes data + 20 byte header = 3020 bytes
MTU: 1500 bytes
Data per fragment: 1500 - 20 (header) = 1480 bytes
Fragment 1:
IP Header (20 bytes)
Identification: 12345
Flags: MF = 1 (More Fragments)
Fragment Offset: 0
Total Length: 1500
Data: 1480 bytes
Fragment 2:
IP Header (20 bytes)
Identification: 12345
Flags: MF = 1
Fragment Offset: 185 (1480/8 = 185)
Total Length: 1500
Data: 1480 bytes
Fragment 3:
IP Header (20 bytes)
Identification: 12345
Flags: MF = 0 (Last Fragment)
Fragment Offset: 370 (2960/8 = 370)
Total Length: 60
Data: 40 bytes
Receiver:
1. Receives all three fragments
2. Checks Identification field (12345) to group them
3. Uses Fragment Offset to order them
4. Reassembles when MF = 0 (last fragment received)
Fragment Offset Calculation
Fragment Offset is in 8-byte units:
Fragment 1: Offset 0 → Bytes 0-1479
Fragment 2: Offset 185 → Bytes 1480-2959 (185 × 8 = 1480)
Fragment 3: Offset 370 → Bytes 2960-2999 (370 × 8 = 2960)
Why 8-byte units?
- 13 bits for offset = max 8191
- 8191 × 8 = 65,528 bytes
- Covers max IP packet size (65,535 bytes)
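The offset arithmetic above can be reproduced with a small helper that splits a payload into MTU-sized fragments and reports each one's (offset in 8-byte units, data length, MF flag):

```python
# Reproducing the fragmentation arithmetic above.
IP_HEADER = 20  # bytes, assuming no options

def fragment(payload_len, mtu):
    per_frag = (mtu - IP_HEADER) // 8 * 8  # data per fragment, 8-byte aligned
    frags, sent = [], 0
    while sent < payload_len:
        size = min(per_frag, payload_len - sent)
        more = sent + size < payload_len   # MF = 1 unless this is the last
        frags.append((sent // 8, size, int(more)))
        sent += size
    return frags

# 3000-byte payload over a 1500-byte MTU, as in the worked example
print(fragment(3000, 1500))  # [(0, 1480, 1), (185, 1480, 1), (370, 40, 0)]
```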
Don't Fragment (DF) Flag
DF = 0: Allow fragmentation
Router can fragment if needed
DF = 1: Don't fragment
Router drops packet if too large
Sends ICMP "Fragmentation Needed" back to source
ICMP Type 3, Code 4:
- Includes MTU of the link
- Source can adjust packet size
Used for Path MTU Discovery (PMTUD)
Path MTU Discovery (PMTUD)
Process:
1. Source sends packet with DF=1 and large size
2. If too large, router drops and sends ICMP
3. Source reduces packet size and retries
4. Repeat until packet gets through
5. Source now knows the path MTU
Example:
Source → [MTU 1500] → Router A → [MTU 1400] → Router B → Dest
1. Send 1500-byte packet, DF=1
2. Router B drops it, sends ICMP: "Frag needed, MTU=1400"
3. Source retries with 1400-byte packets
4. Success! Path MTU = 1400
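The feedback loop can be sketched as a toy simulation. A real implementation reacts to ICMP "Fragmentation Needed" messages from routers; here the "path" is just a list of link MTUs, purely to illustrate the retry loop:

```python
# Toy Path MTU Discovery over a list of link MTUs (not real networking).
def discover_path_mtu(link_mtus, start=1500):
    size = start
    while True:
        for mtu in link_mtus:  # packet traverses the path link by link
            if size > mtu:     # too big: dropped here, ICMP reports this MTU
                size = mtu     # source shrinks the packet and retries
                break
        else:
            return size        # reached the destination: path MTU found

print(discover_path_mtu([1500, 1400]))  # 1400, as in the example above
```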
Fragmentation Issues
Problems:
1. Performance overhead (reassembly)
2. Lost fragment = entire packet lost
3. Difficulty for firewalls/NAT
4. Security concerns (fragment attacks)
Best Practices:
- Avoid fragmentation when possible
- Use TCP MSS clamping
- Enable PMTUD
- Consider smaller packet sizes
TTL (Time to Live)
Purpose
TTL prevents routing loops by limiting packet lifetime:
Without TTL:
Router A → Router B → Router C → Router A → ...
Packet loops forever, congesting network
With TTL:
Source sets TTL = 64
Router 1: Decrements to 63
Router 2: Decrements to 62
...
Router 64: Decrements to 0
→ Drops packet
→ Sends ICMP "Time Exceeded" to source
Common TTL Values
Different operating systems use different initial TTL values:
Operating System Initial TTL
Linux 64
Windows 128
Cisco IOS 255
FreeBSD 64
Mac OS X 64
Solaris 255
Security Note:
Can fingerprint OS based on received TTL
Received TTL = Initial TTL - Hop Count
TTL Example
Packet journey from Source to Destination:
Source (TTL=64)
|
v
Router 1 (TTL=63) → Decrements TTL
|
v
Router 2 (TTL=62) → Decrements TTL
|
v
Router 3 (TTL=61) → Decrements TTL
|
v
Destination (TTL=60) → Receives packet
Reverse calculation:
- Received packet with TTL=60
- If initial TTL was 64
- Hop count = 64 - 60 = 4 hops
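The reverse calculation generalizes: assuming the sender used one of the common initial values (64, 128, 255), the likely initial TTL is the smallest of those not below the received TTL, and the hop count is the difference:

```python
# Estimating initial TTL and hop count from a received TTL,
# assuming one of the common initial values listed above.
COMMON_INITIAL_TTLS = (64, 128, 255)

def estimate_hops(received_ttl):
    initial = min(t for t in COMMON_INITIAL_TTLS if t >= received_ttl)
    return initial, initial - received_ttl

print(estimate_hops(60))  # (64, 4): likely a Linux sender, 4 hops away
</```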
Traceroute Uses TTL
Traceroute maps network paths by manipulating TTL:
How traceroute works:
1. Send packet with TTL=1
→ First router decrements to 0
→ Router drops packet
→ Router sends ICMP "Time Exceeded"
→ We learn first router IP
2. Send packet with TTL=2
→ First router: TTL=1
→ Second router: TTL=0
→ Second router responds
→ We learn second router IP
3. Send packet with TTL=3
→ Continue until destination reached
Result: Map of all routers in path
Example output:
1 192.168.1.1 2ms
2 10.0.0.1 5ms
3 203.0.113.1 10ms
4 93.184.216.34 15ms (destination)
Linux Traceroute Example
# Default (UDP probes)
traceroute google.com
# ICMP probes
traceroute -I google.com
# TCP SYN probes to port 80
traceroute -T -p 80 google.com
# Set max hops
traceroute -m 20 google.com
# Send 3 probes per hop (default)
traceroute -q 3 google.com
Windows Tracert Example
# ICMP probes (Windows default)
tracert google.com
# Set max hops
tracert -h 20 google.com
# Don't resolve addresses to hostnames
tracert -d google.com
IP Routing
Routing Decision Process
When a host needs to send an IP packet:
1. Determine if destination is local:
- Perform bitwise AND of dest IP and subnet mask
- Compare with local network address
Example:
Local IP: 192.168.1.100/24
Dest IP: 192.168.1.50
192.168.1.50 AND 255.255.255.0 = 192.168.1.0 (matches local network)
→ Send directly via ARP
2. If destination is remote:
- Search routing table for matching route
- Use longest prefix match algorithm
- Forward to gateway for that route
3. If no specific route matches:
- Use default gateway (0.0.0.0/0)
4. If no default gateway:
- Destination unreachable error
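Step 1 of the decision process (the bitwise-AND check) is what `in` does for `ipaddress` networks, so the local-vs-gateway choice reduces to a containment test:

```python
import ipaddress

# Local-vs-remote decision for a host on 192.168.1.0/24.
local_net = ipaddress.ip_network("192.168.1.0/24")

def next_hop(dest, default_gateway="192.168.1.1"):
    if ipaddress.ip_address(dest) in local_net:
        return "direct"          # same subnet: ARP for the destination itself
    return default_gateway       # remote: forward to the gateway
```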
Example Routing Table
Destination Gateway Netmask Interface Metric
0.0.0.0 192.168.1.1 0.0.0.0 eth0 100 (Default route)
10.0.0.0 192.168.1.254 255.0.0.0 eth0 10 (Static route)
192.168.1.0 0.0.0.0 255.255.255.0 eth0 0 (Connected)
192.168.2.0 192.168.1.200 255.255.255.0 eth0 20 (Static route)
172.16.0.0 192.168.1.254 255.255.0.0 eth0 15 (Static route)
Routing Table Lookup
Packet destination: 10.1.2.5
Routing table:
0.0.0.0/0 → Gateway A (Default route)
10.0.0.0/8 → Gateway B (Matches!)
10.1.0.0/16 → Gateway C (Matches! More specific)
10.1.2.0/24 → Gateway D (Matches! Most specific)
192.168.1.0/24 → Local (No match)
Longest Prefix Match Algorithm:
- All routes compared
- Most specific match wins (/24 > /16 > /8 > /0)
- Forward to Gateway D
Viewing Routing Table
# Linux - traditional
route -n
netstat -rn
# Linux - modern
ip route show
ip route list
# Windows
route print
netstat -r
# Example output (Linux):
Destination Gateway Genmask Flags Metric Ref Use Iface
0.0.0.0 192.168.1.1 0.0.0.0 UG 100 0 0 eth0
192.168.1.0 0.0.0.0 255.255.255.0 U 0 0 0 eth0
NAT (Network Address Translation)
Why NAT?
Problem: IPv4 Address Exhaustion
- Only ~4.3 billion addresses
- Internet growth exceeded availability
- Need to conserve public IP addresses
Solution: NAT
- Private network uses private IPs (10.x, 172.16-31.x, 192.168.x)
- Single public IP shared by many devices
- Router translates between private and public
How NAT Works
Private Network (192.168.1.0/24)
┌──────────────────────────────┐
│ PC1: 192.168.1.10 │
│ PC2: 192.168.1.11 │──→ NAT Router ──→ Internet
│ PC3: 192.168.1.12 │ (Translates) Public IP: 203.0.113.5
└──────────────────────────────┘
Outbound:
PC1 (192.168.1.10:5000) → NAT → Internet as (203.0.113.5:6000)
Inbound:
Internet → (203.0.113.5:6000) → NAT → PC1 (192.168.1.10:5000)
NAT maintains translation table to track connections
NAT Types
1. Static NAT (One-to-One)
One private IP ↔ One public IP
Configuration:
Private: 192.168.1.10 ↔ Public: 203.0.113.10
Private: 192.168.1.11 ↔ Public: 203.0.113.11
Use case:
- Web servers
- Mail servers
- Devices that need incoming connections
2. Dynamic NAT (Many-to-Many)
Multiple private IPs ↔ Pool of public IPs
Configuration:
Private: 192.168.1.0/24
Public pool: 203.0.113.10 - 203.0.113.20
Connection:
PC1 (192.168.1.10) → Gets 203.0.113.10
PC2 (192.168.1.11) → Gets 203.0.113.11
PC3 (192.168.1.12) → Gets 203.0.113.12
When PC1 disconnects, 203.0.113.10 returns to pool
3. PAT (Port Address Translation) / NAT Overload
Most common type, used in home routers:
Many private IPs ↔ Single public IP (different ports)
Translation table:
Internal IP:Port External IP:Port Remote IP:Port
192.168.1.10:5000 → 203.0.113.5:6000 → 8.8.8.8:53
192.168.1.11:5001 → 203.0.113.5:6001 → 1.1.1.1:443
192.168.1.12:5002 → 203.0.113.5:6002 → 93.184.216.34:80
192.168.1.10:5003 → 203.0.113.5:6003 → 142.250.185.46:443
Note: Same internal IP can have multiple external ports
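The translation table above can be sketched as a toy model in Python. This is illustrative only, not router code; the class name and the starting external port 6000 are made up for the example:

```python
class PatTable:
    """Toy model of a PAT (NAT overload) translation table."""

    def __init__(self, public_ip, first_port=6000):
        self.public_ip = public_ip
        self.next_port = first_port
        self.outbound = {}   # (inside_ip, inside_port) -> external port
        self.inbound = {}    # external port -> (inside_ip, inside_port)

    def translate_out(self, inside_ip, inside_port):
        """Map an internal socket to (public_ip, external_port)."""
        key = (inside_ip, inside_port)
        if key not in self.outbound:          # allocate a new external port
            port = self.next_port
            self.next_port += 1
            self.outbound[key] = port
            self.inbound[port] = key
        return (self.public_ip, self.outbound[key])

    def translate_in(self, external_port):
        """Map a returning packet back to the internal socket (None = drop)."""
        return self.inbound.get(external_port)


nat = PatTable("203.0.113.5")
print(nat.translate_out("192.168.1.10", 5000))  # ('203.0.113.5', 6000)
print(nat.translate_out("192.168.1.11", 5001))  # ('203.0.113.5', 6001)
print(nat.translate_in(6000))                   # ('192.168.1.10', 5000)
```

Unsolicited inbound packets hit no table entry and are dropped, which is why port forwarding (below) needs explicit static mappings.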
4. Port Forwarding (DNAT - Destination NAT)
Allow external connections to internal servers:
Configuration:
External: 203.0.113.5:80 → Internal: 192.168.1.20:80 (Web)
External: 203.0.113.5:443 → Internal: 192.168.1.20:443 (HTTPS)
External: 203.0.113.5:22 → Internal: 192.168.1.25:22 (SSH)
External: 203.0.113.5:3389 → Internal: 192.168.1.30:3389 (RDP)
Internet request to 203.0.113.5:80
→ Router forwards to 192.168.1.20:80
→ Web server responds
→ Router translates source back to 203.0.113.5:80
NAT Translation Table Example
Protocol Inside Local Inside Global Outside Local Outside Global
TCP 192.168.1.10:5000 203.0.113.5:6000 8.8.8.8:53 8.8.8.8:53
TCP 192.168.1.11:5001 203.0.113.5:6001 1.1.1.1:443 1.1.1.1:443
TCP 192.168.1.10:5002 203.0.113.5:6002 93.184.216.34:80 93.184.216.34:80
Terminology:
- Inside Local: Private IP (before NAT)
- Inside Global: Public IP (after NAT)
- Outside Local: Remote IP (before NAT)
- Outside Global: Remote IP (after NAT)
NAT Configuration Examples
Linux (iptables)
# Enable IP forwarding
echo 1 > /proc/sys/net/ipv4/ip_forward
# Basic NAT (masquerade)
iptables -t nat -A POSTROUTING -o eth0 -j MASQUERADE
# Or with specific IP
iptables -t nat -A POSTROUTING -o eth0 -j SNAT --to-source 203.0.113.5
# Port forwarding
iptables -t nat -A PREROUTING -i eth0 -p tcp --dport 80 \
-j DNAT --to-destination 192.168.1.20:80
# View NAT table
iptables -t nat -L -v
Cisco Router
! Enable NAT
interface GigabitEthernet0/0
ip nat outside
interface GigabitEthernet0/1
ip nat inside
! NAT overload (PAT)
ip nat inside source list 1 interface GigabitEthernet0/0 overload
access-list 1 permit 192.168.1.0 0.0.0.255
! Port forwarding
ip nat inside source static tcp 192.168.1.20 80 203.0.113.5 80
! View NAT translations
show ip nat translations
show ip nat statistics
NAT Disadvantages
1. Breaks end-to-end connectivity
- Some protocols don't work (FTP active mode, SIP, H.323)
- Requires ALG (Application Layer Gateway) for some apps
2. Performance overhead
- Translation takes CPU time
- Maintains state tables
3. Complicates peer-to-peer
- NAT traversal techniques needed (STUN, TURN, ICE)
4. Hides internal topology
- All traffic appears from one IP
- Makes troubleshooting harder
5. Limited by port numbers
- At most ~64,000 usable ports per public IP, per transport protocol
- Busy networks can suffer port exhaustion
ICMP (Internet Control Message Protocol)
ICMP is a network layer protocol used for diagnostics and error reporting. It's an integral part of IP.
ICMP Message Format
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Type | Code | Checksum |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
| Message Body |
| (varies by type) |
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
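The Checksum field is the standard Internet checksum (RFC 1071): the one's-complement of the one's-complement sum of the message in 16-bit words, computed with the checksum field set to zero. A minimal sketch:

```python
def internet_checksum(data: bytes) -> int:
    """RFC 1071 Internet checksum over 16-bit big-endian words."""
    if len(data) % 2:
        data += b"\x00"                           # pad odd-length input
    total = 0
    for i in range(0, len(data), 2):
        total += (data[i] << 8) | data[i + 1]
        total = (total & 0xFFFF) + (total >> 16)  # fold the carry back in
    return ~total & 0xFFFF

# ICMP Echo Request header: type=8, code=0, checksum=0, id=0x1234, seq=1
header = bytes([8, 0, 0, 0, 0x12, 0x34, 0, 1])
csum = internet_checksum(header)
# Check: with the checksum filled in, the checksum of the message is 0
filled = header[:2] + csum.to_bytes(2, "big") + header[4:]
print(hex(csum), internet_checksum(filled))  # 0xe5ca 0
```

The same checksum algorithm is used by IPv4 headers, TCP, and UDP.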
Common ICMP Message Types
| Type | Code | Message | Description | Use |
|---|---|---|---|---|
| 0 | 0 | Echo Reply | Response to ping | ping response |
| 3 | 0 | Destination Network Unreachable | Cannot reach network | Routing error |
| 3 | 1 | Destination Host Unreachable | Cannot reach host | Host down/filtered |
| 3 | 2 | Destination Protocol Unreachable | Protocol not supported | Protocol error |
| 3 | 3 | Destination Port Unreachable | Port not listening | Port closed |
| 3 | 4 | Fragmentation Needed and DF Set | MTU exceeded with DF=1 | PMTUD |
| 3 | 13 | Communication Administratively Prohibited | Filtered by firewall | ACL/firewall |
| 5 | 0-3 | Redirect | Better route available | Route optimization |
| 8 | 0 | Echo Request | Ping request | ping |
| 11 | 0 | Time Exceeded in Transit | TTL reached 0 | traceroute |
| 11 | 1 | Fragment Reassembly Time Exceeded | Fragments timeout | Fragmentation issue |
| 12 | 0 | Parameter Problem | IP header error | Malformed packet |
Ping (ICMP Echo Request/Reply)
Ping tests connectivity and measures round-trip time:
Client Server
| |
| ICMP Echo Request (Type 8) |
| Identifier: 1234 |
| Sequence: 1 |
| Data: 56 bytes |
|------------------------------->|
| |
| ICMP Echo Reply (Type 0) |
| Identifier: 1234 |
| Sequence: 1 |
| Data: 56 bytes (echoed) |
|<-------------------------------|
| |
Round-Trip Time (RTT) measured
Ping Examples
# Basic ping
ping 8.8.8.8
# Send specific number of packets
ping -c 4 8.8.8.8
# Set packet size
ping -s 1000 8.8.8.8
# Set interval (0.2 seconds)
ping -i 0.2 8.8.8.8
# Flood ping (requires root)
sudo ping -f 8.8.8.8
# Set TTL
ping -t 5 8.8.8.8
# Disable DNS resolution
ping -n 8.8.8.8
# Example output:
PING 8.8.8.8 (8.8.8.8) 56(84) bytes of data.
64 bytes from 8.8.8.8: icmp_seq=1 ttl=117 time=10.2 ms
64 bytes from 8.8.8.8: icmp_seq=2 ttl=117 time=9.8 ms
64 bytes from 8.8.8.8: icmp_seq=3 ttl=117 time=10.1 ms
--- 8.8.8.8 ping statistics ---
3 packets transmitted, 3 received, 0% packet loss, time 2003ms
rtt min/avg/max/mdev = 9.8/10.0/10.2/0.2 ms
ICMP in Traceroute
Traceroute sends packets with increasing TTL:
Packet 1: TTL=1
→ Router 1 decrements to 0
→ Router 1 sends ICMP Type 11 (Time Exceeded)
→ Reveals Router 1 IP
Packet 2: TTL=2
→ Router 1: TTL=1
→ Router 2: TTL=0
→ Router 2 sends ICMP Type 11
→ Reveals Router 2 IP
Packet N: TTL=N
→ Destination reached
→ Sends ICMP Type 3 (Port Unreachable) or Echo Reply
→ Traceroute completes
IPv4 Commands and Tools
ifconfig / ip (Linux)
# View IP configuration (old style)
ifconfig
# View IP configuration (modern)
ip addr show
ip a
# Show specific interface
ip addr show eth0
# Assign IP address (temporary)
sudo ip addr add 192.168.1.100/24 dev eth0
# Remove IP address
sudo ip addr del 192.168.1.100/24 dev eth0
# Enable interface
sudo ip link set eth0 up
# Disable interface
sudo ip link set eth0 down
# Show interface statistics
ip -s link show eth0
ipconfig (Windows)
# View IP configuration
ipconfig
# View detailed configuration
ipconfig /all
# Renew DHCP lease
ipconfig /renew
# Release DHCP lease
ipconfig /release
# Flush DNS cache
ipconfig /flushdns
# Display DNS cache
ipconfig /displaydns
ip route (Linux)
# Show routing table
ip route show
ip route list
# Add static route
sudo ip route add 10.0.0.0/8 via 192.168.1.254
# Add route via specific interface
sudo ip route add 10.0.0.0/8 dev eth0
# Delete route
sudo ip route del 10.0.0.0/8
# Add default gateway
sudo ip route add default via 192.168.1.1
# Delete default gateway
sudo ip route del default
# Change route
sudo ip route change 10.0.0.0/8 via 192.168.1.253
# Show route to specific destination
ip route get 8.8.8.8
route (Linux/Windows)
# Linux - show routing table
route -n
# Linux - add route
sudo route add -net 10.0.0.0/8 gw 192.168.1.254
# Linux - delete route
sudo route del -net 10.0.0.0/8
# Windows - show routing table
route print
# Windows - add route
route add 10.0.0.0 mask 255.0.0.0 192.168.1.254
# Windows - delete route
route delete 10.0.0.0
# Windows - add persistent route
route -p add 10.0.0.0 mask 255.0.0.0 192.168.1.254
arp (Address Resolution Protocol)
# View ARP cache
arp -a
# View ARP cache for specific interface (Linux)
arp -i eth0
# Add static ARP entry (Linux)
sudo arp -s 192.168.1.50 00:11:22:33:44:55
# Delete ARP entry (Linux)
sudo arp -d 192.168.1.50
# View ARP cache (modern Linux)
ip neigh show
# Delete ARP entry (modern Linux)
sudo ip neigh del 192.168.1.50 dev eth0
IPv4 Best Practices
1. Subnet Design
Plan network hierarchy:
Organization: 10.0.0.0/8
├── Location A: 10.1.0.0/16
│ ├── Servers: 10.1.1.0/24
│ ├── Workstations: 10.1.2.0/24
│ └── Guests: 10.1.3.0/24
├── Location B: 10.2.0.0/16
│ ├── Servers: 10.2.1.0/24
│ └── Workstations: 10.2.2.0/24
└── Management: 10.255.0.0/16
├── Network Devices: 10.255.1.0/24
└── Out-of-band: 10.255.2.0/24
Benefits:
- Logical organization
- Summarization for routing
- Security segmentation
- Growth flexibility
2. IP Address Allocation
Reserve ranges within each subnet:
Example subnet: 192.168.1.0/24
192.168.1.0 Network address (reserved)
192.168.1.1 Gateway (router)
192.168.1.2-10 Infrastructure (switches, APs)
192.168.1.11-50 Servers (static)
192.168.1.51-99 Printers/IoT (static)
192.168.1.100-254 DHCP pool (dynamic)
192.168.1.255 Broadcast address (reserved)
Document everything in IPAM (IP Address Management) system
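The reserved network and broadcast addresses, and the usable host range, can be verified with Python's standard `ipaddress` module (using the example subnet above):

```python
import ipaddress

lan = ipaddress.IPv4Network("192.168.1.0/24")
print(lan.network_address)     # 192.168.1.0 (reserved)
print(lan.broadcast_address)   # 192.168.1.255 (reserved)

hosts = list(lan.hosts())      # the 254 assignable addresses
print(hosts[0], hosts[-1])     # 192.168.1.1 192.168.1.254

# Carve out the DHCP pool (.100-.254) from the usable range
dhcp_pool = hosts[99:]
print(len(dhcp_pool))          # 155
```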
3. Use Private IP Ranges
ALWAYS use private IPs internally:
Small networks: 192.168.x.0/24
Medium networks: 172.16.0.0/12 (172.16.0.0 - 172.31.255.255)
Large networks: 10.0.0.0/8
NEVER use:
- Public IPs internally (causes routing issues)
- TEST-NET ranges (192.0.2.0/24, 198.51.100.0/24, 203.0.113.0/24)
- Multicast ranges (224.0.0.0/4)
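A quick way to check which range an address falls in is `ipaddress`. Note that Python's `is_private` covers all IANA special-purpose ranges (including the TEST-NET documentation blocks), so `is_global` is the better test for "routable on the internet":

```python
import ipaddress

for ip in ["192.168.1.10", "10.5.0.1", "172.20.3.4",  # RFC 1918 private
           "8.8.8.8",                                  # public
           "203.0.113.7"]:                             # TEST-NET-3 (never use)
    a = ipaddress.ip_address(ip)
    print(f"{ip:15} private={a.is_private} global={a.is_global}")
```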
4. Network Documentation
Maintain detailed documentation:
Network Diagram:
- Physical topology
- Logical topology
- IP addressing scheme
- VLAN assignments
Spreadsheet/IPAM:
IP Address | Hostname | MAC Address | Type | Notes
192.168.1.1 | gateway | 00:11:22:33:44:55 | Router | Primary gateway
192.168.1.10 | server1 | 00:11:22:33:44:66 | Server | Web server
192.168.1.11 | server2 | 00:11:22:33:44:77 | Server | Database
192.168.1.50 | printer1 | 00:11:22:33:44:88 | Printer| HP LaserJet
5. DHCP Configuration
DHCP best practices:
- Appropriate lease time:
* Office: 8-24 hours
* Guest: 1-4 hours
* Mobile: 30-60 minutes
- Reserve space for static IPs
- Configure DHCP options:
* Option 3: Default gateway
* Option 6: DNS servers
* Option 15: Domain name
* Option 42: NTP servers
- Redundant DHCP servers (split scope or failover)
- Monitor DHCP scope utilization
6. Network Security
Security measures:
1. Subnetting for segmentation
- Separate user, server, management networks
- Use VLANs
2. Private IPs + NAT
- Hide internal topology
- Conserve public IPs
3. Disable unused services
- No ICMP redirect
- No source routing
4. Ingress/egress filtering
- Block spoofed source IPs (BCP 38 / RFC 2827)
- Filter special-use/bogon ranges (RFC 3330, updated by RFC 6890)
5. Monitor for IP conflicts
- Detect ARP spoofing
- DHCP snooping
7. Avoid IP Conflicts
Prevention:
1. Use DHCP for workstations
2. Static IPs for servers/infrastructure
3. Document all static assignments
4. Configure DHCP exclusions for static range
5. Use DHCP reservations for semi-static hosts
6. Enable IP conflict detection
Detection:
- arping before assigning static IP
- Monitor DHCP logs
- Use network scanning tools
- Enable DHCP conflict detection
ELI10: IPv4 Explained Simply
Think of IPv4 addresses like street addresses for computers:
IPv4 Address (192.168.1.100)
- Like a home address with 4 numbers
- Each number is between 0 and 255
- Separated by dots
- Uniquely identifies your computer on the network
Why 4 Numbers?
Each number is 0-255 (256 possibilities)
256 × 256 × 256 × 256 = 4.3 billion addresses
Problem: We almost ran out!
- Too many computers, phones, tablets
- Solution: NAT (share one public address)
- Future: IPv6 (way more addresses)
Private vs Public
Private IPs (like apartment numbers):
- 192.168.x.x (home networks)
- 10.x.x.x (big companies)
- Only work inside your building (network)
Public IPs (like street addresses):
- Work on the internet
- Must be unique worldwide
- Expensive and limited
Subnets
Like organizing streets into neighborhoods:
City: 10.0.0.0/8 (whole city)
└─ Neighborhood: 10.1.0.0/16 (one area)
└─ Street: 10.1.1.0/24 (one street)
└─ House: 10.1.1.100 (your house)
/24 means: First 3 numbers are the "street", last number is your "house number"
NAT (Sharing One Address)
Your home:
- Router has public IP: 203.0.113.5 (street address)
- Devices have private IPs: 192.168.1.x (apartment numbers)
- Router is like mailroom: forwards mail to right apartment
Routing
Routers are like mail sorting facilities:
- Look at destination address
- Decide which direction to send packet
- Pass to next router
- Repeat until destination reached
Further Resources
- RFC 791 - IPv4 Specification
- RFC 1918 - Private Address Space
- RFC 950 - Internet Standard Subnetting Procedure
- RFC 1812 - Requirements for IPv4 Routers
- RFC 3021 - Using 31-Bit Prefixes on IPv4 Point-to-Point Links
- Subnet Calculator
- CIDR to IPv4 Conversion
- IANA IPv4 Address Space Registry
IPv6 (Internet Protocol version 6)
Overview
IPv6 (Internet Protocol version 6) is the most recent version of the Internet Protocol. It was developed to address the IPv4 address exhaustion problem and to provide improvements in routing, security, and network auto-configuration. IPv6 is designed to replace IPv4 and is the future of internet addressing.
Key Characteristics
| Feature | IPv6 |
|---|---|
| Address Size | 128 bits |
| Address Format | Hexadecimal colon notation (2001:db8::1) |
| Total Addresses | 340 undecillion (2¹²⁸ ≈ 3.4 × 10³⁸) |
| Header Size | 40 bytes (fixed, no options) |
| Checksum | No (delegated to link and transport layers) |
| Fragmentation | Source host only (not by routers) |
| Broadcast | No (replaced by multicast) |
| Configuration | SLAAC (Stateless Auto-Config) or DHCPv6 |
| IPSec | Built-in support (originally mandatory, recommended since RFC 6434) |
| Address Resolution | NDP (Neighbor Discovery Protocol) instead of ARP |
IPv6 Advantages Over IPv4
1. Vast Address Space
- 340 undecillion addresses
- Every grain of sand on Earth could have billions of IPs
- No more address exhaustion
2. Simplified Header
- Fixed 40-byte header (no options)
- Faster processing by routers
- Extension headers for optional features
3. Auto-Configuration
- SLAAC: hosts configure themselves
- No DHCP required (though DHCPv6 available)
- Plug-and-play networking
4. Built-in Security
- IPSec built in (originally mandatory, now recommended per RFC 6434)
- Authentication and encryption
- Better privacy features
5. Better Routing
- Hierarchical addressing
- Smaller routing tables
- More efficient routing
6. No NAT Required
- Every device gets public address
- True end-to-end connectivity
- Simplifies protocols (VoIP, gaming, P2P)
7. Multicast Improvements
- No broadcast (more efficient)
- Built-in multicast support
- Scope-based addressing
IPv6 Packet Format
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|Version| Traffic Class | Flow Label |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Payload Length | Next Header | Hop Limit |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
+ +
| |
+ Source Address +
| (128 bits) |
+ +
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
+ +
| |
+ Destination Address +
| (128 bits) |
+ +
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
+ Extension Headers +
| (if present) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
+ Payload +
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
IPv6 Header Fields (40 bytes fixed)
- Version (4 bits): IP version = 6
- Traffic Class (8 bits): QoS and priority
- Differentiated Services Code Point (DSCP): 6 bits
- Explicit Congestion Notification (ECN): 2 bits
- Similar to IPv4 Type of Service
- Flow Label (20 bits): QoS flow identification
- Identifies packets belonging to same flow
- Used for QoS and ECMP (Equal-Cost Multi-Path)
- Routers can treat flows differently
- Payload Length (16 bits): Length of data after header
- Does NOT include the 40-byte header itself
- Maximum: 65,535 bytes
- Jumbograms (>65,535) use Hop-by-Hop extension
- Next Header (8 bits): Type of next header
- Like IPv4 Protocol field
- Values: 6=TCP, 17=UDP, 58=ICMPv6, 59=No next header
- Or indicates extension header type
- Hop Limit (8 bits): Maximum hops (like IPv4 TTL)
- Decremented by each router
- Packet dropped when reaches 0
- Typical values: 64, 128, 255
- Source Address (128 bits): Sender IPv6 address
- Destination Address (128 bits): Receiver IPv6 address
Comparison with IPv4 Header
Removed from IPv4:
- Header Length (IHL): Fixed at 40 bytes
- Identification, Flags, Fragment Offset: Moved to extension header
- Header Checksum: Redundant (link and transport layers handle it)
- Options: Replaced by extension headers
Added to IPv6:
- Flow Label: QoS identification
Renamed:
- TTL → Hop Limit
- Protocol → Next Header
- Type of Service → Traffic Class
IPv6 Address Format
Full Representation
2001:0db8:0000:0042:0000:8a2e:0370:7334
Structure:
- 8 groups of 4 hexadecimal digits
- Separated by colons
- Each group = 16 bits
- Total: 128 bits
Address Compression Rules
Rule 1: Remove Leading Zeros
Original:
2001:0db8:0000:0042:0000:8a2e:0370:7334
After removing leading zeros:
2001:db8:0:42:0:8a2e:370:7334
Each group can have 1-4 hex digits
Rule 2: Compress Consecutive Zeros with ::
Before: 2001:db8:0:42:0:8a2e:370:7334
After: 2001:db8:0:42::8a2e:370:7334
(Parsers accept this, but RFC 5952 says :: should only
replace a run of two or more zero groups)
Before: 2001:db8:0:0:0:0:0:1
After: 2001:db8::1
Before: 0:0:0:0:0:0:0:1
After: ::1 (loopback)
Before: 0:0:0:0:0:0:0:0
After: :: (unspecified)
IMPORTANT: Can only use :: once per address
(otherwise ambiguous which zeros are compressed)
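Python's `ipaddress` module applies these rules and produces the RFC 5952 canonical form, so it only uses `::` for runs of two or more zero groups:

```python
import ipaddress

addr = ipaddress.IPv6Address("2001:0db8:0000:0042:0000:8a2e:0370:7334")
print(addr)            # 2001:db8:0:42:0:8a2e:370:7334 (single zeros stay)

print(ipaddress.IPv6Address("2001:db8:0:0:0:0:0:1"))   # 2001:db8::1

# .exploded reverses the compression
print(ipaddress.IPv6Address("::1").exploded)
# 0000:0000:0000:0000:0000:0000:0000:0001
```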
Special Addresses
:: Unspecified address
(0.0.0.0 in IPv4)
Used before address is configured
::1 Loopback address
(127.0.0.1 in IPv4)
Local host communication
::ffff:192.0.2.1 IPv4-mapped IPv6 address
Used for IPv4/IPv6 compatibility
Last 32 bits contain IPv4 address
2001:db8::/32 Documentation prefix
Reserved for examples (RFC 3849; the IPv6 counterpart of TEST-NET)
fe80::/10 Link-local prefix
Auto-configured on every interface
ff00::/8 Multicast prefix
IPv6 Address Types
1. Unicast (One-to-One)
Address for a single interface.
Global Unicast Address (GUA)
Prefix: 2000::/3 (2000:0000 to 3fff:ffff)
Routable on the Internet (like public IPv4)
Format:
| 48 bits | 16 bits | 64 bits |
| Global Routing | Subnet | Interface ID |
| Prefix | ID | |
Example:
2001:0db8:1234:0001:0000:0000:0000:0001
|-- Global --||Sub||--- Interface ID ---|
Typically:
- ISP assigns /48 or /56 to customer
- Customer has 65,536 (/48) or 256 (/56) subnets
- Each subnet is /64 with 2^64 addresses
Unique Local Address (ULA)
Prefix: fc00::/7 (fc00:: to fdff::)
Private addressing (like RFC 1918 in IPv4)
Not routed on public internet
Format:
fd00::/8 is used (fc00::/8 reserved for future)
| 8 bits | 40 bits | 16 bits | 64 bits |
| fd | Random | Subnet | Interface ID |
| prefix | Global | ID | |
| | ID | | |
Example:
fd12:3456:789a:0001::1
Generation:
- fd prefix
- 40-bit random number (cryptographically generated)
- Ensures uniqueness even if networks merge
Link-Local Address
Prefix: fe80::/10
Automatically configured on every IPv6-enabled interface
Only valid on the local link (not routed)
Like IPv4 169.254.0.0/16 (APIPA)
Format:
fe80::interface-id/64
Examples:
fe80::1
fe80::20c:29ff:fe9d:8c6a
Uses:
- Neighbor Discovery Protocol (NDP)
- Router discovery
- Address autoconfiguration
- Local communication
Always present, even if GUA configured
2. Anycast (One-to-Nearest)
Address assigned to multiple interfaces
Packet delivered to nearest one (by routing metric)
Use cases:
- Load balancing
- Service discovery
- Root DNS servers (all 13 root server identities use anycast)
Same format as unicast (no special prefix)
Designated as anycast during configuration
Example:
Anycast: 2001:db8::1 assigned to 3 servers
Client sends to 2001:db8::1
Routers deliver to nearest server
3. Multicast (One-to-Many)
Prefix: ff00::/8
Replaces broadcast in IPv4
Packet delivered to all members of multicast group
Format:
| 8 bits | 4 bits | 4 bits | 112 bits |
| ff | Flags | Scope | Group ID |
Flags (4 bits):
0000 = Permanent (well-known)
0001 = Temporary (transient)
Scope (4 bits):
1 = Interface-local
2 = Link-local
5 = Site-local
8 = Organization-local
e = Global
Common Multicast Addresses
Well-Known Multicast:
ff02::1 All nodes on link
(Like 255.255.255.255 broadcast)
ff02::2 All routers on link
ff02::1:2 All DHCP servers/relays on link
ff02::1:ff00:0/104 Solicited-node multicast
Used in Neighbor Discovery
ff05::1:3 All DHCP servers (site-local)
Solicited-Node Multicast:
Format: ff02::1:ff[last 24 bits of address]
Example:
Address: 2001:db8::1234:5678
Solicited-node: ff02::1:ff34:5678
Purpose: Efficient address resolution (NDP)
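Deriving the solicited-node address is pure bit manipulation over the last 24 bits. A sketch (the helper name is invented for illustration):

```python
import ipaddress

def solicited_node(unicast: str) -> str:
    """Build ff02::1:ffXX:XXXX from the last 24 bits of a unicast address."""
    low24 = int.from_bytes(ipaddress.IPv6Address(unicast).packed[-3:], "big")
    # ff02::1:ff00:0/104 base prefix, OR'd with the low 24 bits
    value = (0xFF02 << 112) | (0x1FF << 24) | low24
    return str(ipaddress.IPv6Address(value))

print(solicited_node("2001:db8::1234:5678"))  # ff02::1:ff34:5678
print(solicited_node("2001:db8::2"))          # ff02::1:ff00:2
```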
4. No Broadcast
IPv4 broadcast → IPv6 multicast
IPv4: 192.168.1.255 (broadcast to all)
IPv6: ff02::1 (all-nodes multicast)
Benefits:
- More efficient (only interested hosts listen)
- Reduces network noise
- Scalable
IPv6 Address Structure
EUI-64 (Extended Unique Identifier)
Method to generate interface ID from MAC address:
MAC Address: 00:1A:2B:3C:4D:5E
Step 1: Split in half
00:1A:2B : 3C:4D:5E
Step 2: Insert FF:FE in middle
00:1A:2B:FF:FE:3C:4D:5E
Step 3: Flip 7th bit (Universal/Local bit)
00 → 02 (in binary: 00000000 → 00000010)
Result: 02:1A:2B:FF:FE:3C:4D:5E
Step 4: Format as IPv6 interface ID
021a:2bff:fe3c:4d5e
Full address:
2001:db8:1234:5678:021a:2bff:fe3c:4d5e
Privacy concern: MAC address visible in IP
Solution: Privacy Extensions (RFC 4941)
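The four EUI-64 steps above can be coded directly (function name invented for illustration):

```python
def eui64_interface_id(mac: str) -> str:
    """Derive the EUI-64 IPv6 interface ID from a colon-separated MAC."""
    octets = [int(part, 16) for part in mac.split(":")]
    octets[0] ^= 0x02                              # step 3: flip the U/L bit
    eui = octets[:3] + [0xFF, 0xFE] + octets[3:]   # steps 1-2: insert ff:fe
    # step 4: group into four 16-bit hex fields
    return ":".join(f"{eui[i] << 8 | eui[i + 1]:04x}" for i in range(0, 8, 2))

print(eui64_interface_id("00:1A:2B:3C:4D:5E"))  # 021a:2bff:fe3c:4d5e
```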
Privacy Extensions (RFC 4941)
Problem: EUI-64 exposes MAC address
Allows tracking of devices
Solution: Random interface IDs
- Generated randomly
- Changed periodically (typically daily)
- Temporary addresses for outgoing connections
Example:
Stable: 2001:db8::21a:2bff:fe3c:4d5e (EUI-64, for incoming)
Temporary: 2001:db8::a4b2:76d9:3e21:91f8 (random, for outgoing)
Benefits:
- Privacy protection
- Harder to track users
- Still allows stable addressing for servers
IPv6 Subnetting
Standard Subnet Size: /64
Why /64?
1. SLAAC requires /64
- 64-bit prefix + 64-bit interface ID
2. Massive address space per subnet
- 2^64 = 18,446,744,073,709,551,616 addresses
- 18.4 quintillion addresses per subnet!
- Will never run out
3. Standard recommendation
- RFC 4291, RFC 5375
Even point-to-point links should use /64
(though RFC 6164 permits /127 on router-to-router links,
the IPv6 counterpart of an IPv4 /30 or /31)
Subnet Allocation Example
ISP allocates: 2001:db8::/32
Customer (Enterprise):
Receives: 2001:db8:abcd::/48
| 32 bits | 16 bits | 16 bits | 64 bits |
| ISP Prefix | Customer| Subnet | Interface ID |
| 2001:db8 | abcd | 0-ffff | |
Customer has 2^16 = 65,536 subnets:
2001:db8:abcd:0000::/64
2001:db8:abcd:0001::/64
2001:db8:abcd:0002::/64
...
2001:db8:abcd:ffff::/64
Each subnet has 2^64 addresses
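The /48 → /64 arithmetic can be checked with `ipaddress`:

```python
import ipaddress

site = ipaddress.IPv6Network("2001:db8:abcd::/48")
subnets = site.subnets(new_prefix=64)   # generator over all 65,536 /64s

first = next(subnets)
second = next(subnets)
print(first, second)         # 2001:db8:abcd::/64 2001:db8:abcd:1::/64

print(2 ** (64 - 48))        # 65536 subnets in the /48
print(first.num_addresses)   # 18446744073709551616 (2^64 per subnet)
```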
Hierarchical Addressing
Organization: 2001:db8:abcd::/48
Building 1: 2001:db8:abcd:0100::/56
Floor 1: 2001:db8:abcd:0101::/64
Floor 2: 2001:db8:abcd:0102::/64
Floor 3: 2001:db8:abcd:0103::/64
Building 2: 2001:db8:abcd:0200::/56
Floor 1: 2001:db8:abcd:0201::/64
Floor 2: 2001:db8:abcd:0202::/64
Servers: 2001:db8:abcd:1000::/56
Web: 2001:db8:abcd:1001::/64
Database: 2001:db8:abcd:1002::/64
Email: 2001:db8:abcd:1003::/64
Benefits:
- Logical organization
- Easy summarization
- Simplified routing
- Room for growth
IPv6 Auto-Configuration
SLAAC (Stateless Address Auto-Configuration)
Automatic IPv6 configuration without DHCP:
Process:
1. Link-Local Address Generation
Host creates link-local address (fe80::)
Interface ID from EUI-64 or random
2. Duplicate Address Detection (DAD)
Sends Neighbor Solicitation for its own address
If no response → address is unique
3. Router Solicitation (RS)
Host sends multicast RS to ff02::2 (all routers)
"Are there any routers?"
4. Router Advertisement (RA)
Router responds with:
- Network prefix (e.g., 2001:db8:1234:5678::/64)
- Default gateway address
- DNS servers (if configured)
- Other configuration flags
5. Global Address Formation
Host combines:
- Prefix from RA (2001:db8:1234:5678)
- Interface ID (021a:2bff:fe3c:4d5e)
- Result: 2001:db8:1234:5678:021a:2bff:fe3c:4d5e
6. DAD for Global Address
Verify global address is unique
7. Ready!
Host has link-local and global address
No DHCP server needed!
Flags in RA:
- M (Managed): Use DHCPv6 for addresses
- O (Other): Use DHCPv6 for other info (DNS, NTP, etc.)
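Step 5 (combining the advertised /64 prefix with the interface ID) is just a bitwise OR. A sketch with the example values above:

```python
import ipaddress

prefix = ipaddress.IPv6Network("2001:db8:1234:5678::/64")  # from the RA
iface_id = 0x021A2BFFFE3C4D5E                              # EUI-64 interface ID

addr = ipaddress.IPv6Address(int(prefix.network_address) | iface_id)
print(addr)  # 2001:db8:1234:5678:21a:2bff:fe3c:4d5e
```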
DHCPv6 (Dynamic Host Configuration Protocol for IPv6)
Alternative/supplement to SLAAC:
Stateful DHCPv6:
- Like DHCPv4
- Server assigns addresses
- Tracks assignments
- Use when: Need centralized control
Stateless DHCPv6:
- SLAAC for address
- DHCPv6 for other info (DNS, domain, etc.)
- Use when: Need SLAAC + additional config
DHCPv6 Messages:
- SOLICIT: Client requests address
- ADVERTISE: Server offers address
- REQUEST: Client accepts offer
- REPLY: Server confirms
Multicast addresses:
- ff02::1:2 - All DHCP servers/relays on link
- ff05::1:3 - All DHCP servers (site-local)
Router Advertisement Example
Router configuration (Linux):
# Enable IPv6 forwarding
net.ipv6.conf.all.forwarding = 1
# radvd configuration
interface eth0 {
AdvSendAdvert on;
prefix 2001:db8:1234:5678::/64 {
AdvOnLink on;
AdvAutonomous on;
};
RDNSS 2001:4860:4860::8888 {
};
};
This advertises:
- Prefix: 2001:db8:1234:5678::/64
- DNS: 2001:4860:4860::8888 (Google DNS)
- Clients auto-configure themselves
Neighbor Discovery Protocol (NDP)
NDP replaces ARP and adds functionality:
NDP Functions
1. Router Discovery
- Find routers on link
- Get network prefix
2. Address Resolution
- Map IPv6 address to MAC address
- Replaces ARP
3. Duplicate Address Detection (DAD)
- Verify address uniqueness
4. Neighbor Unreachability Detection
- Monitor neighbor reachability
5. Redirect
- Inform hosts of better next hop
NDP Message Types (ICMPv6)
Type 133: Router Solicitation (RS)
Sent by: Host
To: ff02::2 (all routers)
Purpose: "Are there routers here?"
Type 134: Router Advertisement (RA)
Sent by: Router
To: ff02::1 (all nodes) or unicast
Purpose: "Here's my prefix and config"
Type 135: Neighbor Solicitation (NS)
Sent by: Host
To: Solicited-node multicast
Purpose: "Who has this IPv6 address?" (like ARP request)
"Is anyone using this address?" (DAD)
Type 136: Neighbor Advertisement (NA)
Sent by: Host
To: Unicast or ff02::1
Purpose: "I have this IPv6 address" (like ARP reply)
"I'm using this address" (DAD response)
Type 137: Redirect
Sent by: Router
To: Unicast (specific host)
Purpose: "Use different router for this destination"
Address Resolution Example
Host A wants to communicate with Host B:
Host A: 2001:db8::1
Host B: 2001:db8::2
Host B MAC: 00:11:22:33:44:55
1. Host A sends Neighbor Solicitation (NS):
From: 2001:db8::1
To: ff02::1:ff00:2 (solicited-node multicast for ::2)
Question: "What's the MAC address of 2001:db8::2?"
2. Host B receives NS (listening on solicited-node multicast)
3. Host B sends Neighbor Advertisement (NA):
From: 2001:db8::2
To: 2001:db8::1 (unicast reply)
Answer: "My MAC is 00:11:22:33:44:55"
4. Host A caches: 2001:db8::2 → 00:11:22:33:44:55
5. Host A sends packet directly to Host B
Neighbor cache entry:
2001:db8::2 dev eth0 lladdr 00:11:22:33:44:55 REACHABLE
Duplicate Address Detection (DAD)
Before using any address:
1. Node creates address:
- Link-local: fe80::1
- Or global: 2001:db8::1
2. Node sends Neighbor Solicitation:
From: :: (unspecified address)
To: ff02::1:ff00:1 (solicited-node multicast)
Target: fe80::1 (address being tested)
Question: "Is anyone using fe80::1?"
3. Wait 1 second:
- If NA received → Address in use (conflict!)
- If no response → Address is unique ✓
4. If unique:
- Mark address as valid
- Start using it
If conflict detected:
- Link-local conflict: Generate new interface ID
- Global conflict: Manual intervention required
IPv6 Extension Headers
Extension headers provide optional functionality without bloating the main header:
Extension Header Types
Next Header values:
0 = Hop-by-Hop Options (must be first if present)
43 = Routing Header
44 = Fragment Header
50 = Encapsulating Security Payload (ESP)
51 = Authentication Header (AH)
60 = Destination Options
59 = No Next Header (no more headers)
6 = TCP
17 = UDP
58 = ICMPv6
Extension Header Chaining
Base IPv6 Header
Next Header = 43 (Routing)
↓
Routing Header
Next Header = 44 (Fragment)
↓
Fragment Header
Next Header = 60 (Destination Options)
↓
Destination Options Header
Next Header = 6 (TCP)
↓
TCP Header and Data
Recommended order (RFC 2460, now RFC 8200):
1. IPv6 base header
2. Hop-by-Hop Options
3. Destination Options (for intermediate destinations)
4. Routing
5. Fragment
6. Authentication (AH)
7. Encapsulating Security Payload (ESP)
8. Destination Options (for final destination)
9. Upper layer (TCP, UDP, ICMPv6, etc.)
Fragment Header
Fragmentation only at source (not routers!)
Format:
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Next Header | Reserved | Fragment Offset |Res|M|
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Identification |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Next Header: Protocol after reassembly
Fragment Offset: Position in original packet (8-byte units)
M flag: More Fragments (1 = more, 0 = last)
Identification: Groups fragments together
Process:
1. Source performs Path MTU Discovery
2. A router that cannot forward the packet drops it and
sends ICMPv6 "Packet Too Big" with its link MTU
3. Source lowers its packet size, or fragments at the
source using the Fragment header
4. Destination reassembles the fragments
Note: Routers never fragment!
ICMPv6 (Internet Control Message Protocol for IPv6)
ICMPv6 is integral to IPv6 operation:
ICMPv6 Message Types
Error Messages:
1 Destination Unreachable
Code 0: No route to destination
Code 1: Communication with destination administratively prohibited
Code 3: Address unreachable
Code 4: Port unreachable
2 Packet Too Big
Used for Path MTU Discovery
Includes MTU of next hop
3 Time Exceeded
Code 0: Hop limit exceeded in transit
Code 1: Fragment reassembly time exceeded
4 Parameter Problem
Code 0: Erroneous header field
Code 1: Unrecognized Next Header type
Informational Messages:
128 Echo Request (ping)
129 Echo Reply (ping response)
Neighbor Discovery (part of ICMPv6):
133 Router Solicitation
134 Router Advertisement
135 Neighbor Solicitation
136 Neighbor Advertisement
137 Redirect
Multicast Listener Discovery:
130 Multicast Listener Query
131 Multicast Listener Report
132 Multicast Listener Done
Ping6 Example
# Basic ping
ping6 2001:4860:4860::8888
# Ping link-local (must specify interface)
ping6 fe80::1%eth0
# Set packet size
ping6 -s 1000 2001:4860:4860::8888
# Set hop limit
ping6 -t 5 2001:4860:4860::8888
# Example output:
PING 2001:4860:4860::8888(2001:4860:4860::8888) 56 data bytes
64 bytes from 2001:4860:4860::8888: icmp_seq=1 ttl=118 time=10.2 ms
64 bytes from 2001:4860:4860::8888: icmp_seq=2 ttl=118 time=9.9 ms
--- 2001:4860:4860::8888 ping statistics ---
2 packets transmitted, 2 received, 0% packet loss, time 1001ms
rtt min/avg/max/mdev = 9.900/10.050/10.200/0.150 ms
Path MTU Discovery
IPv6 requires source to fragment:
1. Source sends large packet (1500 bytes)
2. Router with smaller MTU (1400 bytes):
- Cannot fragment (not allowed in IPv6)
- Drops packet
- Sends ICMPv6 Type 2 "Packet Too Big"
- Includes MTU value (1400)
3. Source receives ICMPv6:
- Reduces packet size to 1400
- Retransmits
4. Success!
- Source caches PMTU for destination
- Uses smaller packets for this destination
Benefits:
- No fragmentation overhead at routers
- Better performance
- Source controls fragmentation
IPv6 Commands and Tools
IPv6 Configuration (Linux)
# View IPv6 addresses
ip -6 addr show
ip -6 a
# Add IPv6 address
sudo ip -6 addr add 2001:db8::1/64 dev eth0
# Remove IPv6 address
sudo ip -6 addr del 2001:db8::1/64 dev eth0
# Enable IPv6 on interface
sudo sysctl -w net.ipv6.conf.eth0.disable_ipv6=0
# Disable IPv6 on interface
sudo sysctl -w net.ipv6.conf.eth0.disable_ipv6=1
# View IPv6 routing table
ip -6 route show
# Add IPv6 route
sudo ip -6 route add 2001:db8::/32 via 2001:db8::1
# Add default route
sudo ip -6 route add default via fe80::1 dev eth0
# View neighbor cache (NDP)
ip -6 neigh show
IPv6 Configuration (Windows)
# View IPv6 configuration
netsh interface ipv6 show config
ipconfig
# Add IPv6 address
netsh interface ipv6 add address "Ethernet" 2001:db8::1/64
# Remove IPv6 address
netsh interface ipv6 delete address "Ethernet" 2001:db8::1
# Add route
netsh interface ipv6 add route 2001:db8::/32 "Ethernet" 2001:db8::1
# View IPv6 routing table
netsh interface ipv6 show route
route print -6
# View neighbor cache
netsh interface ipv6 show neighbors
Testing Connectivity
# Ping IPv6 address
ping6 2001:4860:4860::8888
ping -6 google.com
# Ping link-local (requires interface specification)
ping6 fe80::1%eth0
ping6 -I eth0 fe80::1
# Traceroute
traceroute6 google.com
traceroute -6 google.com
# TCP connection test
telnet 2001:4860:4860::8888 80
nc -6 google.com 80
# DNS lookup
host google.com
dig AAAA google.com
nslookup -type=AAAA google.com
Network Diagnostics
# View IPv6 sockets
ss -6 -tuln
netstat -6 -tuln
# View IPv6 connections
ss -6 -tun
netstat -6 -tun
# tcpdump for IPv6
sudo tcpdump -i eth0 ip6
sudo tcpdump -i eth0 'icmp6'
sudo tcpdump -i eth0 'ip6 and tcp port 80'
# Neighbor Discovery monitoring
sudo tcpdump -i eth0 'icmp6 and (ip6[40] >= 133 and ip6[40] <= 137)'
IPv6 Best Practices
1. Address Planning
Use /48 for sites:
- Gives 65,536 subnets
- Future-proof
- Standard recommendation
Use /64 for subnets:
- Required for SLAAC
- Standard LAN size
- Even for point-to-point
Use /56 for small sites:
- 256 subnets
- Acceptable for small deployments
Hierarchy example:
2001:db8:abcd::/48 Organization
2001:db8:abcd:0100::/56 Building 1
2001:db8:abcd:0101::/64 Floor 1
2001:db8:abcd:0102::/64 Floor 2
2001:db8:abcd:0200::/56 Building 2
2001:db8:abcd:1000::/56 Data center
2. Dual Stack
Run IPv4 and IPv6 simultaneously:
Benefits:
- Smooth transition
- Backward compatibility
- No disruption
Implementation:
- Enable IPv6 on all interfaces
- Maintain IPv4 for legacy
- Configure both protocols on servers
- Use DNS with A and AAAA records
Eventually:
- IPv6-only for new deployments
- IPv4 only where necessary
3. Security
IPv6-specific security considerations:
1. ICMPv6 is essential
- Don't block all ICMPv6
- Allow NDP (types 133-137)
- Allow PMTU Discovery (type 2)
- Allow Echo Request/Reply (types 128-129)
2. Link-local security
- fe80::/10 should stay local
- Don't route link-local
3. Disable IPv6 if not using
- But preferably, enable and secure it
- Attacks can use IPv6 if enabled but unmonitored
4. RA Guard
- Prevent rogue router advertisements
- Protect against MITM attacks
5. Extension headers
- Many firewalls can't inspect them
- Consider filtering or limiting
6. Privacy Extensions
- Enable for client devices
- Prevents tracking via EUI-64
4. DNS Configuration
Always configure both records:
example.com. IN A 192.0.2.1 (IPv4)
example.com. IN AAAA 2001:db8::1 (IPv6)
Test both:
dig A example.com
dig AAAA example.com
Reverse DNS:
IPv4: 1.2.0.192.in-addr.arpa
IPv6: 1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa
Configure resolver:
/etc/resolv.conf:
nameserver 2001:4860:4860::8888
nameserver 2001:4860:4860::8844
nameserver 8.8.8.8
5. Firewalling
# ip6tables example
# Allow established connections
ip6tables -A INPUT -m state --state ESTABLISHED,RELATED -j ACCEPT
# Allow loopback
ip6tables -A INPUT -i lo -j ACCEPT
# Allow ICMPv6
ip6tables -A INPUT -p ipv6-icmp -j ACCEPT
# Allow SSH
ip6tables -A INPUT -p tcp --dport 22 -j ACCEPT
# Allow HTTP/HTTPS
ip6tables -A INPUT -p tcp --dport 80 -j ACCEPT
ip6tables -A INPUT -p tcp --dport 443 -j ACCEPT
# Drop invalid packets
ip6tables -A INPUT -m state --state INVALID -j DROP
# Default drop
ip6tables -P INPUT DROP
ip6tables -P FORWARD DROP
ip6tables -P OUTPUT ACCEPT
6. Monitoring
# Monitor NDP
ip -6 neigh show
watch -n 1 'ip -6 neigh show'
# Monitor IPv6 traffic
sudo iftop -f "ip6"
sudo nethogs -6
# View IPv6 statistics
netstat -s -6
# Monitor routing
ip -6 route show
watch -n 1 'ip -6 route show'
# Check for IPv6 connectivity
ping6 -c 1 2001:4860:4860::8888 && echo "IPv6 works" || echo "IPv6 fails"
IPv6 Transition Mechanisms
1. Dual Stack
Run both IPv4 and IPv6:
Advantages:
+ Simple
+ No translation
+ Native performance
Disadvantages:
- Must manage both protocols
- Requires IPv4 addresses (scarce)
Best for: Long-term transition
2. Tunneling
6in4 (Manual Tunnel)
IPv6 packets encapsulated in IPv4:
[IPv4 Header][IPv6 Header][Data]
Configuration:
# Linux
ip tunnel add ipv6tunnel mode sit remote 198.51.100.1 local 192.0.2.1
ip link set ipv6tunnel up
ip addr add 2001:db8::2/64 dev ipv6tunnel
ip route add ::/0 dev ipv6tunnel
Use case: Static IPv6 over IPv4
6to4
Automatic tunneling using 2002::/16:
IPv4: 192.0.2.1
IPv6: 2002:c000:0201::/48
(c000:0201 = 192.0.2.1 in hex)
Deprecated: Security issues
Teredo
Tunneling for NAT environments:
Prefix: 2001::/32
Use case: Windows clients behind NAT
Status: Deprecated, use native IPv6
3. NAT64/DNS64
Allow IPv6-only clients to access IPv4 services:
IPv6 client (2001:db8::1)
↓ Request "www.example.com"
DNS64 server
↓ Returns 64:ff9b::192.0.2.1 (synthesized AAAA)
IPv6 client
↓ Connects to 64:ff9b::192.0.2.1
NAT64 gateway
↓ Translates to IPv4: 192.0.2.1
IPv4 server (192.0.2.1)
Use case: IPv6-only networks accessing IPv4 internet
IPv6 vs IPv4 Comparison
| Feature | IPv4 | IPv6 |
|---|---|---|
| Address length | 32 bits | 128 bits |
| Address format | Decimal (192.0.2.1) | Hexadecimal (2001:db8::1) |
| Address space | 4.3 billion | 340 undecillion |
| Header size | 20-60 bytes (variable) | 40 bytes (fixed) |
| Checksum | Yes | No |
| Fragmentation | Routers and source | Source only |
| Broadcast | Yes | No (multicast) |
| Multicast | Optional | Built-in |
| IPSec | Optional | Mandatory |
| Address resolution | ARP | NDP |
| Auto-configuration | DHCP | SLAAC or DHCPv6 |
| NAT | Common | Not needed |
| Options | In header | Extension headers |
| Jumbograms | No | Yes (>65535 bytes) |
| Mobile IP | Extension | Built-in |
ELI10: IPv6 Explained Simply
Think of IPv6 as a massive upgrade to the internet's addressing system:
The Address Problem
IPv4 (old):
- Like phone numbers with 10 digits
- Only 4.3 billion addresses
- Running out (like phone numbers in 1990s)
IPv6 (new):
- Like phone numbers with 39 digits
- 340 undecillion addresses
- Enough for every atom on Earth to have trillions of IPs
- We'll NEVER run out
Address Format
IPv4: 192.168.1.1
- Four numbers (0-255)
- Separated by dots
IPv6: 2001:db8::1
- Eight groups of hex digits (0-9, a-f)
- Separated by colons
- Can compress zeros with ::
Auto-Configuration
IPv4:
- Need DHCP server
- Manual configuration for servers
- "Hey DHCP, give me an address!"
IPv6:
- Auto-configures itself (SLAAC)
- Listens for router
- Creates own address
- "I'll make my own address, thanks!"
No More NAT
IPv4 with NAT:
Home: All devices share one public IP
Like apartment building with one mailbox
IPv6:
Home: Every device gets its own public IP
Like every apartment having its own mailbox
Direct delivery, no sharing needed
Better Security
IPv4:
- Security added later (IPSec optional)
- Like adding locks to old houses
IPv6:
- Security built-in (IPSec mandatory)
- Like new houses with locks included
Link-Local Addresses
Every IPv6 device has:
1. Link-local (fe80::): For local network (like intercom)
2. Global (2001:...): For internet (like phone number)
Always have both, automatic!
Further Resources
- RFC 8200 - IPv6 Specification
- RFC 4291 - IPv6 Addressing Architecture
- RFC 4862 - IPv6 Stateless Address Autoconfiguration
- RFC 4861 - Neighbor Discovery for IPv6
- RFC 4941 - Privacy Extensions for SLAAC
- RFC 3484 - Default Address Selection
- IPv6 Test - Test your IPv6 connectivity
- Hurricane Electric IPv6 Certification - Free IPv6 training
TCP (Transmission Control Protocol)
TCP is a connection-oriented, reliable transport layer protocol that provides ordered delivery of data between applications running on hosts in an IP network. It is one of the core protocols of the Internet Protocol Suite.
Key Features
- Connection-Oriented: Establishes connection before data transfer
- Reliable: Guarantees delivery of data in order
- Error Checking: Detects corrupted data with checksums
- Flow Control: Manages data transmission rate
- Congestion Control: Adjusts to network conditions
- Full-Duplex: Bidirectional communication
TCP Packet Format
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Port | Destination Port |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Sequence Number |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Acknowledgment Number |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Data | |C|E|U|A|P|R|S|F| |
| Offset| Rsrvd |W|C|R|C|S|S|Y|I| Window |
| | |R|E|G|K|H|T|N|N| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Checksum | Urgent Pointer |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Options | Padding |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Data |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Header Fields
- Source Port (16 bits): Sending application port number
- Destination Port (16 bits): Receiving application port number
- Sequence Number (32 bits): Position of first data byte in segment
- Acknowledgment Number (32 bits): Next expected sequence number
- Data Offset (4 bits): Size of TCP header in 32-bit words
- Reserved (3 bits): Reserved for future use
- Control Flags (9 bits): Connection control flags
- Window Size (16 bits): Receive window size
- Checksum (16 bits): Error detection
- Urgent Pointer (16 bits): Offset of urgent data
- Options (variable): Optional header extensions
- Padding: Ensures header is multiple of 32 bits
Control Flags
- CWR (Congestion Window Reduced): ECN-Echo flag received
- ECE (ECN-Echo): Congestion experienced
- URG (Urgent): Urgent pointer field is valid
- ACK (Acknowledgment): Acknowledgment number is valid
- PSH (Push): Push buffered data to application
- RST (Reset): Reset the connection
- SYN (Synchronize): Synchronize sequence numbers (connection setup)
- FIN (Finish): No more data from sender (connection termination)
Three-Way Handshake
TCP uses a three-way handshake to establish a connection:
Client Server
| |
| SYN (seq=x) |
|-------------------------------------->|
| |
| SYN-ACK (seq=y, ack=x+1) |
|<--------------------------------------|
| |
| ACK (seq=x+1, ack=y+1) |
|-------------------------------------->|
| |
| Connection Established |
| |
- SYN: Client sends SYN packet with initial sequence number
- SYN-ACK: Server responds with SYN-ACK, includes its own sequence number
- ACK: Client sends ACK to confirm, connection established
Python Example: TCP Client
import socket
# Create TCP socket
client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Connect to server (three-way handshake happens here)
server_address = ('localhost', 8080)
client_socket.connect(server_address)
print(f"Connected to {server_address}")
# Send data
message = "Hello, Server!"
client_socket.sendall(message.encode())
# Receive response
response = client_socket.recv(1024)
print(f"Received: {response.decode()}")
# Close connection
client_socket.close()
Python Example: TCP Server
import socket
# Create TCP socket
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Bind to address and port
server_address = ('localhost', 8080)
server_socket.bind(server_address)
# Listen for connections
server_socket.listen(5)
print(f"Server listening on {server_address}")
while True:
# Accept connection (completes three-way handshake)
client_socket, client_address = server_socket.accept()
print(f"Connection from {client_address}")
try:
# Receive data
data = client_socket.recv(1024)
print(f"Received: {data.decode()}")
# Send response
response = "Hello, Client!"
client_socket.sendall(response.encode())
finally:
# Close connection
client_socket.close()
Connection Termination
TCP uses a four-way handshake to close a connection gracefully:
Client Server
| |
| FIN (seq=x) |
|-------------------------------------->|
| |
| ACK (ack=x+1) |
|<--------------------------------------|
| |
| FIN (seq=y) |
|<--------------------------------------|
| |
| ACK (ack=y+1) |
|-------------------------------------->|
| |
| Connection Closed |
- FIN: Active closer sends FIN
- ACK: Passive closer acknowledges FIN
- FIN: Passive closer sends its FIN
- ACK: Active closer acknowledges FIN
TCP State Machine
CLOSED
|
| (active open/SYN)
v
SYN-SENT
|
| (SYN received/SYN-ACK sent)
v
SYN-RECEIVED
|
| (ACK received)
v
ESTABLISHED
|
| (close/FIN sent)
v
FIN-WAIT-1
|
| (ACK received)
v
FIN-WAIT-2
|
| (FIN received/ACK sent)
v
TIME-WAIT
|
| (2*MSL timeout)
v
CLOSED
TCP States
- CLOSED: No connection
- LISTEN: Server waiting for connection request
- SYN-SENT: Client sent SYN, waiting for SYN-ACK
- SYN-RECEIVED: Server received SYN, sent SYN-ACK
- ESTABLISHED: Connection established, data transfer
- FIN-WAIT-1: Sent FIN, waiting for ACK
- FIN-WAIT-2: Received ACK of FIN, waiting for peer FIN
- CLOSE-WAIT: Received FIN, waiting for close
- CLOSING: Both sides sent FIN simultaneously
- LAST-ACK: Waiting for final ACK
- TIME-WAIT: Waiting to ensure remote received ACK
- CLOSED: Connection fully terminated
Check Connection States
# Linux - Show all TCP connections
netstat -tan
# Show listening ports
netstat -tln
# Show established connections
netstat -tan | grep ESTABLISHED
# Alternative: ss command (faster)
ss -tan
ss -tln
ss -tan state established
# Show connection state for specific port
ss -tan '( dport = :80 or sport = :80 )'
Flow Control
TCP uses a sliding window protocol for flow control:
import socket
import time
def tcp_receiver_with_flow_control():
"""
Receiver controls flow using window size
"""
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.bind(('localhost', 8080))
server.listen(1)
conn, addr = server.accept()
# Set receive buffer size (affects window size)
conn.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 4096)
total_received = 0
while True:
data = conn.recv(1024)
if not data:
break
total_received += len(data)
print(f"Received {len(data)} bytes, total: {total_received}")
# Simulate slow processing
time.sleep(0.1)
conn.close()
server.close()
def tcp_sender():
"""
Sender adapts to receiver's window size
"""
client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
client.connect(('localhost', 8080))
# Set send buffer size
client.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 8192)
# Send large amount of data
data = b'X' * 100000
sent = 0
while sent < len(data):
chunk = data[sent:sent+1024]
try:
bytes_sent = client.send(chunk)
sent += bytes_sent
print(f"Sent {bytes_sent} bytes, total: {sent}")
except socket.error as e:
print(f"Send error: {e}")
break
client.close()
Congestion Control
TCP implements congestion control algorithms:
Algorithms
- Slow Start: Exponentially increase congestion window
- Congestion Avoidance: Linearly increase window
- Fast Retransmit: Retransmit on 3 duplicate ACKs
- Fast Recovery: Reduce window, avoid slow start
Window Size
^
| Slow Start | Congestion Avoidance
| /|
| / |
| / |_______________
| / | \
| / | \
| / | \ Fast Recovery
| / | \_______________
| / |
|/________________|________________________> Time
Threshold
Check TCP Congestion Control
# Linux - Check current algorithm
sysctl net.ipv4.tcp_congestion_control
# Available algorithms
sysctl net.ipv4.tcp_available_congestion_control
# Set congestion control algorithm
sudo sysctl -w net.ipv4.tcp_congestion_control=cubic
# Common algorithms:
# - cubic (default on most Linux)
# - reno (traditional)
# - bbr (Google's BBR)
# - vegas
Retransmission
TCP retransmits lost or corrupted packets:
import socket
import time
def tcp_with_timeout():
"""
TCP automatically handles retransmission
"""
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Set timeout for operations
sock.settimeout(5.0)
try:
sock.connect(('example.com', 80))
# Send HTTP request
request = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"
sock.sendall(request)
# Receive response
response = sock.recv(4096)
print(f"Received {len(response)} bytes")
except socket.timeout:
print("Operation timed out - TCP retransmission may be occurring")
except socket.error as e:
print(f"Socket error: {e}")
finally:
sock.close()
Retransmission Timeout (RTO)
# Linux - View TCP retransmission statistics
netstat -s | grep -i retrans
# Check retransmission timer settings
sysctl net.ipv4.tcp_retries1 # Threshold for alerting
sysctl net.ipv4.tcp_retries2 # Maximum retries before giving up
# Typical values:
# tcp_retries1 = 3 (alert after 3-6 seconds)
# tcp_retries2 = 15 (give up after ~13-30 minutes)
TCP Options
Common TCP options in the header:
Maximum Segment Size (MSS)
import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Set TCP_MAXSEG option (MSS)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_MAXSEG, 1400)
# Get current MSS
mss = sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_MAXSEG)
print(f"TCP MSS: {mss}")
Window Scaling
# Enable TCP window scaling (Linux)
sudo sysctl -w net.ipv4.tcp_window_scaling=1
# Check current setting
sysctl net.ipv4.tcp_window_scaling
Selective Acknowledgment (SACK)
# Enable SACK (Linux)
sudo sysctl -w net.ipv4.tcp_sack=1
# Check current setting
sysctl net.ipv4.tcp_sack
Timestamps
# Enable TCP timestamps
sudo sysctl -w net.ipv4.tcp_timestamps=1
# Check current setting
sysctl net.ipv4.tcp_timestamps
TCP vs UDP
| Feature | TCP | UDP |
|---|---|---|
| Connection | Connection-oriented | Connectionless |
| Reliability | Guaranteed delivery | No guarantee |
| Ordering | In-order delivery | No ordering |
| Speed | Slower (overhead) | Faster (minimal overhead) |
| Header Size | 20-60 bytes | 8 bytes |
| Error Checking | Yes (checksum) | Yes (checksum) |
| Flow Control | Yes | No |
| Congestion Control | Yes | No |
| Use Cases | HTTP, FTP, SSH, Email | DNS, VoIP, Streaming, Gaming |
When to Use TCP
- File transfers
- Web browsing
- Remote shell (SSH)
- Any application requiring reliability
When to Use UDP
- Real-time applications (VoIP, video streaming)
- DNS queries
- Online gaming
- IoT devices with small data
- Broadcasting/multicasting
Performance Tuning
Socket Buffer Sizes
import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Increase buffer sizes for high-throughput applications
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 1024) # 1MB receive
sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024 * 1024) # 1MB send
# Get buffer sizes
rcvbuf = sock.getsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF)
sndbuf = sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF)
print(f"Receive buffer: {rcvbuf}, Send buffer: {sndbuf}")
TCP Keepalive
import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Enable keepalive
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
# Set keepalive parameters (Linux)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 60) # Start after 60s
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 10) # Interval 10s
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 3) # Retry 3 times
Nagle's Algorithm
import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Disable Nagle's algorithm for low-latency applications
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
# Check status
nodelay = sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)
print(f"TCP_NODELAY: {nodelay}")
Linux Kernel Tuning
# Increase maximum buffer sizes
sudo sysctl -w net.core.rmem_max=16777216
sudo sysctl -w net.core.wmem_max=16777216
# Set TCP buffer sizes (min, default, max)
sudo sysctl -w net.ipv4.tcp_rmem="4096 87380 16777216"
sudo sysctl -w net.ipv4.tcp_wmem="4096 65536 16777216"
# Increase backlog queue
sudo sysctl -w net.core.somaxconn=1024
sudo sysctl -w net.ipv4.tcp_max_syn_backlog=2048
# Enable TCP Fast Open
sudo sysctl -w net.ipv4.tcp_fastopen=3
# Reuse TIME_WAIT sockets
sudo sysctl -w net.ipv4.tcp_tw_reuse=1
Troubleshooting
Analyze TCP with tcpdump
# Capture TCP traffic on port 80
sudo tcpdump -i any tcp port 80 -n
# Capture SYN packets
sudo tcpdump 'tcp[tcpflags] & (tcp-syn) != 0' -n
# Capture RST packets
sudo tcpdump 'tcp[tcpflags] & (tcp-rst) != 0' -n
# Save to file for analysis
sudo tcpdump -i any tcp port 80 -w capture.pcap
# Read from file
tcpdump -r capture.pcap -n
Analyze with Wireshark
# Start Wireshark
wireshark
# Useful display filters:
# tcp.port == 80
# tcp.flags.syn == 1
# tcp.flags.reset == 1
# tcp.analysis.retransmission
# tcp.analysis.duplicate_ack
# tcp.window_size_value < 1000
Common Issues
Connection Refused
# Check if port is listening
netstat -tln | grep :80
# Check firewall
sudo iptables -L -n | grep 80
Connection Timeout
# Test connectivity
telnet example.com 80
# Check routing
traceroute example.com
# Test with timeout
timeout 5 telnet example.com 80
Slow Connection
import socket
import time
def measure_tcp_performance():
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Measure connection time
start = time.time()
sock.connect(('example.com', 80))
connect_time = time.time() - start
print(f"Connection time: {connect_time:.3f}s")
# Send request
request = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"
start = time.time()
sock.sendall(request)
send_time = time.time() - start
print(f"Send time: {send_time:.3f}s")
# Receive response
start = time.time()
data = sock.recv(4096)
recv_time = time.time() - start
print(f"Receive time: {recv_time:.3f}s")
print(f"Received {len(data)} bytes")
sock.close()
measure_tcp_performance()
Monitoring TCP Connections
# Real-time connection monitoring
watch -n 1 'netstat -tan | grep ESTABLISHED | wc -l'
# Connection state distribution
netstat -tan | awk '{print $6}' | sort | uniq -c
# Show connections with process info
sudo netstat -tanp
# Alternative with ss
ss -tanp state established
Advanced Topics
TCP Fast Open (TFO)
Reduces latency by sending data in SYN packet:
import socket
# Client with TFO
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Enable TFO (requires kernel support)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_FASTOPEN, 1)
# Send data during connection (SYN packet)
data = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"
sock.sendto(data, socket.MSG_FASTOPEN, ('example.com', 80))
TCP Multipath (MPTCP)
Allows connection over multiple paths:
# Check if MPTCP is available (Linux)
sysctl net.mptcp.enabled
# Enable MPTCP
sudo sysctl -w net.mptcp.enabled=1
Zero Copy
Improve performance with zero-copy operations:
import socket
import os
def sendfile_example(sock, filename):
"""
Send file using zero-copy sendfile
"""
with open(filename, 'rb') as f:
# Get file size
file_size = os.fstat(f.fileno()).st_size
# Send file using sendfile (zero-copy)
offset = 0
while offset < file_size:
sent = os.sendfile(sock.fileno(), f.fileno(), offset, file_size - offset)
offset += sent
Best Practices
- Always close sockets: Use try-finally or context managers
- Set appropriate timeouts: Avoid hanging indefinitely
- Handle errors gracefully: Network can fail at any time
- Use connection pooling: Reuse connections for better performance
- Enable keepalive for long connections: Detect dead connections
- Tune buffer sizes for workload: Larger for throughput, smaller for latency
- Monitor connection states: Watch for TIME_WAIT buildup
- Use TCP_NODELAY for interactive apps: Reduce latency
- Enable window scaling for high-bandwidth: Support larger windows
- Test under load: Verify behavior under stress
References
- RFC 793 - TCP Specification
- RFC 1323 - TCP Extensions (Window Scaling, Timestamps)
- RFC 2018 - TCP Selective Acknowledgment
- RFC 7413 - TCP Fast Open
- RFC 8684 - Multipath TCP
UDP (User Datagram Protocol)
Overview
UDP is a connectionless transport layer protocol that provides fast, unreliable data transmission. Unlike TCP, UDP doesn't guarantee delivery, ordering, or error checking, making it ideal for time-sensitive applications where speed matters more than reliability.
UDP vs TCP
| Feature | UDP | TCP |
|---|---|---|
| Connection | Connectionless | Connection-oriented |
| Reliability | Unreliable (no guarantee) | Reliable (guaranteed delivery) |
| Ordering | No ordering | Ordered delivery |
| Speed | Fast (low overhead) | Slower (more overhead) |
| Header Size | 8 bytes | 20-60 bytes |
| Error Checking | Optional checksum | Mandatory checksum + retransmission |
| Flow Control | None | Yes (window-based) |
| Congestion Control | None | Yes |
| Use Cases | Streaming, gaming, DNS, VoIP | File transfer, web, email |
UDP Packet Format
0 7 8 15 16 23 24 31
+--------+--------+--------+--------+
| Source | Destination |
| Port | Port |
+--------+--------+--------+--------+
| | |
| Length | Checksum |
+--------+--------+--------+--------+
| |
| Data octets ... |
+-----------------------------------+
Header Fields (8 bytes total)
- Source Port (16 bits): Port number of sender (optional, can be 0)
- Destination Port (16 bits): Port number of receiver
- Length (16 bits): Length of UDP header + data (minimum 8 bytes)
- Checksum (16 bits): Error checking (optional in IPv4, mandatory in IPv6)
Example UDP Header
Source Port: 53210 (0xCFCA)
Destination Port: 53 (0x0035) - DNS
Length: 512 bytes
Checksum: 0x1A2B
Hexadecimal representation:
CF CA 00 35 02 00 1A 2B
[... 504 bytes of data ...]
How UDP Works
Sending Data
Application → Socket → UDP Layer → IP Layer → Network
1. Application writes data to UDP socket
2. UDP adds 8-byte header
3. UDP passes datagram to IP layer
4. IP sends packet to destination
5. No acknowledgment expected
Receiving Data
Network → IP Layer → UDP Layer → Socket → Application
1. IP receives packet
2. IP passes to UDP based on protocol number (17)
3. UDP validates checksum (if present)
4. UDP delivers to application based on port
5. If port not listening, send ICMP "Port Unreachable"
UDP Communication Flow
One-Way Communication (Fire and Forget)
Client Server (port 9000)
| |
| UDP Packet (Hello) |
|------------------------------->|
| |
| UDP Packet (World) |
|------------------------------->|
| |
No handshake, no acknowledgment
Two-Way Communication (Request-Response)
Client Server
| |
| DNS Query (Port 53) |
|------------------------------->|
| |
| DNS Response |
|<-------------------------------|
| |
Application must handle timeouts and retries
UDP Checksum Calculation
Pseudo Header (for checksum calculation)
+--------+--------+--------+--------+
| Source IP Address |
+--------+--------+--------+--------+
| Destination IP Address |
+--------+--------+--------+--------+
| zero |Protocol| UDP Length |
+--------+--------+--------+--------+
Checksum Process
- Create pseudo header from IP information
- Concatenate: Pseudo header + UDP header + data
- Divide into 16-bit words
- Sum all 16-bit words
- Add carry bits to result
- Take one's complement
Example:
def calculate_checksum(data):
# Sum all 16-bit words
total = sum(struct.unpack("!%dH" % (len(data)//2), data))
# Add carry
total = (total >> 16) + (total & 0xffff)
total += (total >> 16)
# One's complement
return ~total & 0xffff
Common UDP Ports
| Port | Service | Purpose |
|---|---|---|
| 53 | DNS | Domain name resolution |
| 67/68 | DHCP | Dynamic IP configuration |
| 69 | TFTP | Trivial File Transfer |
| 123 | NTP | Network Time Protocol |
| 161/162 | SNMP | Network management |
| 514 | Syslog | System logging |
| 520 | RIP | Routing protocol |
| 1900 | SSDP | Service discovery (UPnP) |
| 3478 | STUN | NAT traversal |
| 5353 | mDNS | Multicast DNS |
UDP Use Cases
1. DNS (Domain Name System)
Client sends UDP query to port 53:
+----------------+
| DNS Query |
| example.com? |
+----------------+
Server responds:
+----------------+
| DNS Response |
| 93.184.216.34 |
+----------------+
Fast lookup, retry if no response
2. Video Streaming
Server sends video frames continuously:
Frame 1 → Frame 2 → Frame 3 → Frame 4 → Frame 5
If Frame 3 is lost, continue with Frame 4
(Old frame is useless for live streaming)
3. Online Gaming
Game Client → Server: Player position updates (60 FPS)
Update 1: Player at (100, 200)
Update 2: Player at (101, 201)
Update 3: [LOST]
Update 4: Player at (103, 203)
Lost packet is okay - next update corrects position
4. VoIP (Voice over IP)
Continuous audio stream:
Packet 1: Audio 0-20ms
Packet 2: Audio 20-40ms
Packet 3: Audio 40-60ms [LOST]
Packet 4: Audio 60-80ms
Lost packet = brief audio glitch
Retransmission would cause worse delay
5. DHCP (IP Address Assignment)
Client Server
| |
| DHCP Discover (broadcast) |
|------------------------------->|
| |
| DHCP Offer |
|<-------------------------------|
| |
| DHCP Request |
|------------------------------->|
| |
| DHCP ACK |
|<-------------------------------|
UDP Socket Programming
Python UDP Server
import socket
# Create UDP socket
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
# Bind to address and port
server_address = ('localhost', 9000)
sock.bind(server_address)
print(f"UDP server listening on {server_address}")
while True:
# Receive data (up to 1024 bytes)
data, client_address = sock.recvfrom(1024)
print(f"Received {len(data)} bytes from {client_address}")
print(f"Data: {data.decode()}")
# Send response
sock.sendto(b"Message received", client_address)
Python UDP Client
import socket
# Create UDP socket
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
server_address = ('localhost', 9000)
try:
# Send data
message = b"Hello, UDP Server!"
sock.sendto(message, server_address)
# Receive response (with timeout)
sock.settimeout(5.0)
data, server = sock.recvfrom(1024)
print(f"Received: {data.decode()}")
except socket.timeout:
print("No response from server")
finally:
sock.close()
C UDP Server
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#define PORT 9000
#define BUFFER_SIZE 1024
int main() {
int sockfd;
char buffer[BUFFER_SIZE];
struct sockaddr_in server_addr, client_addr;
socklen_t addr_len = sizeof(client_addr);
// Create UDP socket
sockfd = socket(AF_INET, SOCK_DGRAM, 0);
// Setup server address
memset(&server_addr, 0, sizeof(server_addr));
server_addr.sin_family = AF_INET;
server_addr.sin_addr.s_addr = INADDR_ANY;
server_addr.sin_port = htons(PORT);
// Bind socket
bind(sockfd, (struct sockaddr*)&server_addr, sizeof(server_addr));
printf("UDP server listening on port %d\n", PORT);
while(1) {
// Receive data
int n = recvfrom(sockfd, buffer, BUFFER_SIZE, 0,
(struct sockaddr*)&client_addr, &addr_len);
buffer[n] = '\0';
printf("Received: %s\n", buffer);
// Send response
sendto(sockfd, "ACK", 3, 0,
(struct sockaddr*)&client_addr, addr_len);
}
return 0;
}
UDP Broadcast and Multicast
Broadcast (One-to-All in subnet)
import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
# Send to broadcast address
broadcast_address = ('255.255.255.255', 9000)
sock.sendto(b"Broadcast message", broadcast_address)
Multicast (One-to-Many selected)
import socket
import struct
MCAST_GRP = '224.1.1.1'
MCAST_PORT = 5007
# Sender
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.sendto(b"Multicast message", (MCAST_GRP, MCAST_PORT))
# Receiver
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.bind(('', MCAST_PORT))
# Join multicast group
mreq = struct.pack("4sl", socket.inet_aton(MCAST_GRP),
socket.INADDR_ANY)
sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)
data, address = sock.recvfrom(1024)
UDP Maximum Packet Size
Theoretical Limits
IPv4:
- Max IP packet: 65,535 bytes
- IP header: 20 bytes (minimum)
- UDP header: 8 bytes
- Max UDP data: 65,507 bytes
IPv6:
- Max payload (jumbogram): 4,294,967,295 bytes
Practical Limits (MTU)
Ethernet MTU: 1500 bytes
- IP header: 20 bytes
- UDP header: 8 bytes
- Safe UDP data: 1472 bytes
To avoid fragmentation:
- Stay under 1472 bytes for IPv4
- Stay under 1452 bytes for IPv6
UDP Reliability Techniques
Since UDP doesn't provide reliability, applications must implement it:
1. Acknowledgments
Sender Receiver
| |
| Packet 1 |
|------------------------------->|
| |
| ACK 1 |
|<-------------------------------|
| |
| Packet 2 |
|------------------------------->|
| |
[timeout - no ACK received]
| |
| Packet 2 (resend) |
|------------------------------->|
| |
| ACK 2 |
|<-------------------------------|
2. Sequence Numbers
Application adds sequence numbers:
Packet 1: [Seq=1][Data]
Packet 2: [Seq=2][Data]
Packet 3: [Seq=3][Data]
Receiver detects missing packets
Requests retransmission if needed
3. Timeouts and Retries
import socket
import time
def send_with_retry(sock, data, address, max_retries=3):
for attempt in range(max_retries):
sock.sendto(data, address)
sock.settimeout(1.0)
try:
response, _ = sock.recvfrom(1024)
return response
except socket.timeout:
print(f"Retry {attempt + 1}/{max_retries}")
continue
raise Exception("Max retries exceeded")
UDP Advantages
- Low Latency: No connection setup, immediate transmission
- Low Overhead: 8-byte header vs TCP's 20+ bytes
- No Connection State: Simpler, uses less memory
- Broadcast/Multicast: Can send to multiple receivers
- Fast: No waiting for acknowledgments
- Transaction-Oriented: Good for request-response
UDP Disadvantages
- Unreliable: Packets may be lost, duplicated, or reordered
- No Flow Control: Can overwhelm receiver
- No Congestion Control: Can worsen network congestion
- No Security: No encryption (use DTLS for secure UDP)
- Application Complexity: Must implement reliability if needed
UDP Security Considerations
Vulnerabilities
- UDP Flood Attack: Overwhelm server with UDP packets
- UDP Amplification: Small request → large response (DNS, NTP)
- Spoofing: Easy to fake source IP (no handshake)
Mitigation
1. Rate limiting: Limit packets per second per source
2. Firewall rules: Block unnecessary UDP ports
3. Authentication: Verify sender identity
4. DTLS: Encrypted UDP (Datagram TLS)
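The rate-limiting mitigation is often implemented as a token bucket per source address. A minimal, hypothetical sketch (the `TokenBucket` class and its parameters are illustrative, not from any specific firewall):

```python
class TokenBucket:
    """Per-source token-bucket rate limiter (illustrative sketch)."""
    def __init__(self, rate, burst):
        self.rate = rate      # tokens added per second
        self.burst = burst    # bucket capacity
        self.buckets = {}     # source addr -> (tokens, last_timestamp)

    def allow(self, addr, now):
        tokens, last = self.buckets.get(addr, (self.burst, now))
        tokens = min(self.burst, tokens + (now - last) * self.rate)
        if tokens >= 1:
            self.buckets[addr] = (tokens - 1, now)
            return True
        self.buckets[addr] = (tokens, now)
        return False
```

In a real deployment `now` would come from a monotonic clock and the state would need eviction for idle sources.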
DTLS (Datagram TLS)
Secure UDP communication:
UDP + TLS-style encryption = DTLS
Used in:
- WebRTC
- VPN protocols
- IoT devices
Monitoring UDP Traffic
Using tcpdump
# Capture UDP traffic on port 53 (DNS)
tcpdump -i any udp port 53
# Capture all UDP traffic
tcpdump -i any udp
# Save to file
tcpdump -i any udp -w udp_capture.pcap
# View UDP packet details
tcpdump -i any udp -vv -X
Using netstat
# Show UDP listening ports
netstat -uln
# Show UDP statistics
netstat -su
# Show processes using UDP
netstat -unp
ELI10
Think of the two protocols as ways of sending mail.
TCP is like certified mail:
- You get confirmation it arrived
- Items arrive in order
- Lost mail is resent
- But takes longer
UDP is like postcards:
- Just drop it in the mailbox and go
- Super fast - no waiting
- But might get lost
- Might arrive out of order
- No way to know if it arrived
When to use UDP (postcards):
- Quick questions (DNS: "What's this address?")
- Live streaming (watching a game - who cares about 1 missed frame?)
- Online games (your position updates 60 times per second)
- Video calls (slight glitch is better than delay)
When to use TCP (certified mail):
- Important files
- Web pages
- Emails
- Banking transactions
Further Resources
HTTP/HTTPS
Overview
HTTP (HyperText Transfer Protocol) is the foundation of data communication on the web. HTTPS adds encryption for secure communication.
HTTP Basics
Request-Response Model
Client Server
HTTP Request ->
<- HTTP Response
HTTP Methods
| Method | Purpose | Idempotent | Safe |
|---|---|---|---|
| GET | Retrieve resource | Yes | Yes |
| POST | Create resource | No | No |
| PUT | Replace resource | Yes | No |
| PATCH | Partial update | No | No |
| DELETE | Remove resource | Yes | No |
| HEAD | Like GET, no body | Yes | Yes |
| OPTIONS | Describe options | Yes | Yes |
Status Codes
| Code | Meaning | Examples |
|---|---|---|
| 1xx | Informational | 100 Continue |
| 2xx | Success | 200 OK, 201 Created |
| 3xx | Redirection | 301 Moved, 304 Not Modified |
| 4xx | Client Error | 400 Bad Request, 404 Not Found |
| 5xx | Server Error | 500 Server Error, 503 Unavailable |
Headers
Request Headers:
Host: example.com
User-Agent: Mozilla/5.0
Accept: application/json
Authorization: Bearer token123
Cookie: session=abc123
Content-Type: application/json
Response Headers:
Content-Type: application/json
Content-Length: 256
Set-Cookie: session=def456
Cache-Control: max-age=3600
ETag: "12345abc"
HTTP Versions
| Version | Released | Features |
|---|---|---|
| HTTP/1.1 | 1997 | Keep-alive, chunked transfer |
| HTTP/2 | 2015 | Multiplexing, server push, binary |
| HTTP/3 | 2022 | QUIC protocol, faster |
REST API Design
Resource-Oriented
GET /users - List users
POST /users - Create user
GET /users/123 - Get user 123
PUT /users/123 - Update user 123
DELETE /users/123 - Delete user 123
GET /getUser?id=123 - Procedural (bad)
POST /createUser - Procedural (bad)
Request/Response Example
# Request
GET /users/123 HTTP/1.1
Host: api.example.com
Authorization: Bearer token
# Response
HTTP/1.1 200 OK
Content-Type: application/json
Content-Length: 156
{
  "id": 123,
  "name": "John",
  "email": "john@example.com"
}
HTTPS (Secure HTTP)
Adds TLS encryption on top of HTTP:
HTTP over TLS = HTTPS
Benefits
- Encryption: Data unreadable to eavesdroppers
- Authentication: Verify server identity
- Integrity: Detect tampering
Certificate Process
1. Generate private/public key pair
2. Request certificate from CA
3. CA verifies and signs certificate
4. Browser verifies signature with CA's public key
Caching
Cache Headers
Cache-Control: max-age=3600 # Cache for 1 hour
Cache-Control: no-cache # Validate before use
Cache-Control: no-store # Don't cache
Cache-Control: private # Only browser cache
Cache-Control: public # Any cache can store
ETag: "12345" # Resource version
Conditional Requests
If-None-Match: "12345"
-> Returns 304 Not Modified if unchanged
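Server-side, the conditional-request logic boils down to an ETag comparison. A hypothetical sketch (`handle_get` and its parameters are illustrative, not a real framework API):

```python
def handle_get(request_headers, resource_etag, body):
    """Return (status, body), honoring If-None-Match."""
    if request_headers.get("If-None-Match") == resource_etag:
        return 304, b""          # client's cached copy is still fresh
    return 200, body

print(handle_get({"If-None-Match": '"12345"'}, '"12345"', b"data"))  # (304, b'')
```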
Authentication
Basic Auth
Authorization: Basic base64(username:password)
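The `base64(username:password)` encoding can be reproduced with the standard library (`basic_auth_header` is an illustrative helper, not a standard API):

```python
import base64

def basic_auth_header(username, password):
    """Build an HTTP Basic Authorization header value."""
    token = base64.b64encode(f"{username}:{password}".encode()).decode()
    return f"Basic {token}"

print(basic_auth_header("user", "pass"))  # Basic dXNlcjpwYXNz
```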
Bearer Token
Authorization: Bearer eyJhbGc...
OAuth 2.0
Multi-step authorization flow for 3rd party apps
CORS (Cross-Origin Resource Sharing)
Enable browser to access cross-origin APIs:
Server Response:
Access-Control-Allow-Origin: *
Access-Control-Allow-Methods: GET, POST
Access-Control-Allow-Headers: Content-Type
Common Issues
404 Not Found
Resource doesn't exist
401 Unauthorized
Missing/invalid authentication
403 Forbidden
Authenticated but not allowed
429 Too Many Requests
Rate limit exceeded
503 Service Unavailable
Server temporarily down
Best Practices
1. Use Appropriate Methods
GET for reading (no side effects)
POST for creating
PUT for full replacement
PATCH for partial update
2. Meaningful Status Codes
200 OK for success
201 Created for new resource
204 No Content for delete
200 for everything (bad)
3. Versioning
/api/v1/users
/api/v2/users
4. Error Responses
{
  "error": "Invalid input",
  "details": {
    "email": "Email format invalid"
  }
}
ELI10
HTTP is like sending letters:
- GET: "What's the address of 123 Main St?"
- POST: "Please add my address to your system"
- PUT: "Update my address to..."
- DELETE: "Remove my address"
The server sends back a number (status code):
- 200: "Got it, here's what you asked for!"
- 404: "Can't find that address"
- 500: "I have a problem..."
HTTPS adds a sealed envelope so only the right person can read it!
Further Resources
DNS (Domain Name System)
Overview
DNS is the internet's phonebook that translates human-readable domain names (like example.com) into IP addresses (like 93.184.216.34) that computers use to identify each other on the network.
DNS Hierarchy
Root (.)
|
+-------------+-------------+
| | |
.com .org .net
| | |
example.com wikipedia.org archive.net
|
www.example.com
DNS Record Types
| Record Type | Purpose | Example |
|---|---|---|
| A | IPv4 address | example.com -> 93.184.216.34 |
| AAAA | IPv6 address | example.com -> 2606:2800:220:1:... |
| CNAME | Canonical name (alias) | www.example.com -> example.com |
| MX | Mail exchange server | example.com -> mail.example.com |
| NS | Name server | example.com -> ns1.example.com |
| TXT | Text information | SPF, DKIM records |
| PTR | Reverse DNS lookup | 34.216.184.93 -> example.com |
| SOA | Start of authority | Zone information |
| SRV | Service location | _service._proto.name |
DNS Query Process
1. User types "example.com" in browser
2. Browser checks local cache
3. If not cached, query DNS resolver (ISP or 8.8.8.8)
4. Resolver checks its cache
5. If not cached, recursive query:
Resolver → Root DNS Server
Root → "Ask .com server"
Resolver → .com TLD Server
TLD → "Ask example.com's nameserver"
Resolver → example.com's Nameserver
Nameserver → "IP is 93.184.216.34"
6. Resolver caches result and returns to browser
7. Browser connects to IP address
DNS Message Format
+---------------------------+
| Header | 12 bytes
+---------------------------+
| Question | Variable
+---------------------------+
| Answer | Variable
+---------------------------+
| Authority | Variable
+---------------------------+
| Additional | Variable
+---------------------------+
Header Format (12 bytes)
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
| ID |
+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|QR| Opcode |AA|TC|RD|RA| Z | RCODE |
+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
| QDCOUNT |
+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
| ANCOUNT |
+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
| NSCOUNT |
+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
| ARCOUNT |
+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
Fields:
- ID: 16-bit identifier for matching requests/responses
- QR: Query (0) or Response (1)
- Opcode: Query type (0=standard, 1=inverse, 2=status)
- AA: Authoritative Answer
- TC: Truncated (message too long for UDP)
- RD: Recursion Desired
- RA: Recursion Available
- RCODE: Response code (0=no error, 3=name error)
- QDCOUNT: Number of questions
- ANCOUNT: Number of answers
- NSCOUNT: Number of authority records
- ARCOUNT: Number of additional records
DNS Query Example
Query (Request)
; DNS Query for example.com A record
; Header
ID: 0x1234
Flags: 0x0100 (standard query, recursion desired)
Questions: 1
Answer RRs: 0
Authority RRs: 0
Additional RRs: 0
; Question Section
example.com. IN A
Hexadecimal representation:
12 34 01 00 00 01 00 00 00 00 00 00
07 65 78 61 6d 70 6c 65 03 63 6f 6d 00
00 01 00 01
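The hex dump above can be reproduced byte-for-byte with `struct` (a sketch mirroring the query shown; `build_dns_query` is an illustrative helper):

```python
import struct

def build_dns_query(name, qtype=1, qclass=1, query_id=0x1234, flags=0x0100):
    """Serialize a standard DNS query: header plus one question."""
    # Header: ID, flags, QDCOUNT=1, ANCOUNT/NSCOUNT/ARCOUNT=0
    msg = struct.pack("!HHHHHH", query_id, flags, 1, 0, 0, 0)
    # Question name: length-prefixed labels, terminated by a zero byte
    for label in name.split("."):
        msg += bytes([len(label)]) + label.encode()
    msg += b"\x00" + struct.pack("!HH", qtype, qclass)
    return msg

print(build_dns_query("example.com").hex())
```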
Response
; DNS Response for example.com A record
; Header
ID: 0x1234
Flags: 0x8180 (response, recursion available)
Questions: 1
Answer RRs: 1
Authority RRs: 0
Additional RRs: 0
; Question Section
example.com. IN A
; Answer Section
example.com. 86400 IN A 93.184.216.34
DNS Query Types
Recursive Query
Client asks DNS server to provide the final answer:
Client → Resolver: "What's example.com?"
Resolver → Root/TLD/Auth servers (multiple queries)
Resolver → Client: "It's 93.184.216.34"
Iterative Query
DNS server returns best answer it knows:
Client → Root: "What's example.com?"
Root → Client: "Ask .com server at 192.5.6.30"
Client → TLD: "What's example.com?"
TLD → Client: "Ask ns1.example.com at 192.0.2.1"
Client → Auth: "What's example.com?"
Auth → Client: "It's 93.184.216.34"
DNS Resource Record Format
Name: example.com
Type: A (1)
Class: IN (1) - Internet
TTL: 86400 (24 hours)
Data Length: 4
Data: 93.184.216.34
Common DNS Operations
Using dig (DNS lookup tool)
# Basic A record lookup
dig example.com
# Query specific record type
dig example.com MX
dig example.com AAAA
# Query specific DNS server
dig @8.8.8.8 example.com
# Reverse DNS lookup
dig -x 93.184.216.34
# Trace DNS resolution path
dig +trace example.com
# Short answer only
dig +short example.com
Using nslookup
# Basic lookup
nslookup example.com
# Query specific server
nslookup example.com 8.8.8.8
# Query specific record type
nslookup -type=MX example.com
Using host
# Simple lookup
host example.com
# Verbose output
host -v example.com
# Query MX records
host -t MX example.com
DNS Caching
Cache Levels
- Browser Cache: Short-lived (seconds to minutes)
- OS Cache: System-level DNS cache
- Router Cache: Local network cache
- ISP Resolver Cache: Hours to days
- Authoritative Server: The source of truth
TTL (Time To Live)
Controls how long records are cached:
example.com. 3600 IN A 93.184.216.34
^^^^
1 hour TTL
Flushing DNS Cache
# Windows
ipconfig /flushdns
# macOS
sudo dscacheutil -flushcache
# Linux (systemd-resolved)
sudo resolvectl flush-caches   # older systems: sudo systemd-resolve --flush-caches
# Linux (nscd)
sudo /etc/init.d/nscd restart
DNS Security
DNS Spoofing/Cache Poisoning
Attack where fake DNS responses are injected:
Attacker intercepts DNS query
Attacker sends fake response: "bank.com -> evil.com"
Victim connects to attacker's server
Prevention: DNSSEC
DNSSEC (DNS Security Extensions)
Adds cryptographic signatures to DNS records:
1. Zone owner signs DNS records with private key
2. Public key published in DNS
3. Resolver verifies signature
4. Chain of trust from root to domain
Record Types:
- RRSIG: Contains signature
- DNSKEY: Public key
- DS: Delegation Signer (links parent to child)
DNS over HTTPS (DoH)
Encrypts DNS queries using HTTPS:
Client → DoH Server (port 443)
Encrypted: "What's example.com?"
Encrypted: "It's 93.184.216.34"
Providers:
- Cloudflare: https://1.1.1.1/dns-query
- Google: https://dns.google/dns-query
DNS over TLS (DoT)
Encrypts DNS queries using TLS:
Client → DoT Server (port 853)
TLS encrypted DNS query/response
Public DNS Servers
| Provider | IPv4 | IPv6 | Features |
|---|---|---|---|
| Google | 8.8.8.8, 8.8.4.4 | 2001:4860:4860::8888 | Fast, reliable |
| Cloudflare | 1.1.1.1, 1.0.0.1 | 2606:4700:4700::1111 | Privacy-focused |
| Quad9 | 9.9.9.9 | 2620:fe::fe | Malware blocking |
| OpenDNS | 208.67.222.222 | 2620:119:35::35 | Content filtering |
DNS Load Balancing
Multiple A records for load distribution:
example.com. 300 IN A 192.0.2.1
example.com. 300 IN A 192.0.2.2
example.com. 300 IN A 192.0.2.3
Round-robin or geographic distribution of requests.
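Client-side, round-robin over the returned A records can be sketched with `itertools.cycle` (the `RoundRobin` class is illustrative):

```python
from itertools import cycle

class RoundRobin:
    """Rotate through the A records returned for one name (sketch)."""
    def __init__(self, addresses):
        self._cycle = cycle(addresses)

    def next_address(self):
        return next(self._cycle)

rr = RoundRobin(["192.0.2.1", "192.0.2.2", "192.0.2.3"])
print([rr.next_address() for _ in range(4)])  # wraps back to the first
```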
Common DNS Response Codes
| Code | Name | Meaning |
|---|---|---|
| 0 | NOERROR | Query successful |
| 1 | FORMERR | Format error |
| 2 | SERVFAIL | Server failure |
| 3 | NXDOMAIN | Domain doesn't exist |
| 4 | NOTIMP | Not implemented |
| 5 | REFUSED | Query refused |
DNS Best Practices
1. Use Multiple Nameservers
NS ns1.example.com (Primary)
NS ns2.example.com (Secondary)
2. Appropriate TTL Values
# Stable records (rarely change)
example.com. 86400 IN A 93.184.216.34
# Dynamic records (may change soon)
staging.example.com. 300 IN A 192.0.2.1
3. SPF Records for Email
example.com. IN TXT "v=spf1 mx include:_spf.google.com ~all"
4. DKIM for Email Authentication
default._domainkey.example.com. IN TXT "v=DKIM1; k=rsa; p=MIGfMA0..."
DNS Troubleshooting
Issue: Domain not resolving
# Check if domain exists
dig example.com
# Check all nameservers
dig example.com NS
dig @ns1.example.com example.com
# Check propagation
dig @8.8.8.8 example.com
dig @1.1.1.1 example.com
Issue: Slow DNS resolution
# Test query time
dig example.com | grep "Query time"
# Compare different DNS servers
dig @8.8.8.8 example.com | grep "Query time"
dig @1.1.1.1 example.com | grep "Query time"
Issue: NXDOMAIN (domain not found)
- Check domain registration
- Verify nameserver configuration
- Check DNS propagation time (up to 48 hours)
Zone File Example
$TTL 86400
@ IN SOA ns1.example.com. admin.example.com. (
2024011301 ; Serial
3600 ; Refresh
1800 ; Retry
604800 ; Expire
86400 ) ; Minimum TTL
; Name servers
IN NS ns1.example.com.
IN NS ns2.example.com.
; Mail servers
IN MX 10 mail1.example.com.
IN MX 20 mail2.example.com.
; A records
@ IN A 93.184.216.34
www IN A 93.184.216.34
mail1 IN A 192.0.2.1
mail2 IN A 192.0.2.2
ns1 IN A 192.0.2.10
ns2 IN A 192.0.2.11
; AAAA records (IPv6)
@ IN AAAA 2606:2800:220:1:248:1893:25c8:1946
; CNAME records
ftp IN CNAME www.example.com.
webmail IN CNAME mail1.example.com.
ELI10
DNS is like a phone book for the internet:
- Without DNS: "Visit 93.184.216.34" (hard to remember!)
- With DNS: "Visit example.com" (easy!)
When you type a website name:
- Your computer asks "Where is example.com?"
- DNS looks it up in its huge phone book
- DNS says "It's at 93.184.216.34"
- Your computer connects to that address
DNS servers are like helpers who:
- Remember answers (caching) so they can answer faster next time
- Ask other DNS servers if they don't know the answer
- Make sure everyone gets the same answer for the same website
Further Resources
mDNS (Multicast DNS)
Overview
mDNS (Multicast DNS) is a protocol that resolves hostnames to IP addresses within small networks without requiring a conventional DNS server. It's part of Zero Configuration Networking (Zeroconf) and enables devices to discover each other on local networks using the .local domain.
Why mDNS?
Traditional DNS Limitations
Problem: Home networks lack DNS servers
Traditional setup requires:
1. DNS server
2. Manual configuration
3. Static IP or DHCP integration
4. Administrative overhead
mDNS solution:
- No DNS server needed
- Automatic hostname resolution
- Zero configuration
- Works out of the box
Use Cases
1. Printer discovery
- printer.local → 192.168.1.100
2. File sharing
- macbook.local → 192.168.1.50
3. IoT devices
- raspberry-pi.local → 192.168.1.75
4. Local development
- webserver.local → 127.0.0.1
5. Service discovery
- Find all printers on network
- Find all file servers
How mDNS Works
Query Process
Device wants to find "printer.local"
1. Send multicast query to 224.0.0.251:5353
"Who has printer.local?"
2. All devices receive query
3. Device with hostname "printer" responds
"I'm printer.local at 192.168.1.100"
4. Querying device caches response
5. Direct communication established
Multicast Address
IPv4: 224.0.0.251
IPv6: ff02::fb
Port: 5353 (UDP)
All devices on local network listen to this address
mDNS Message Format
DNS-Compatible Format
mDNS uses standard DNS message format:
+---------------------------+
| Header |
+---------------------------+
| Question |
+---------------------------+
| Answer |
+---------------------------+
| Authority |
+---------------------------+
| Additional |
+---------------------------+
Header Fields
ID: Usually 0 (multicast)
QR: Query (0) or Response (1)
OPCODE: 0 (standard query)
AA: Authoritative Answer (1 for responses)
TC: Truncated
RD: Recursion Desired (0 for mDNS)
RA: Recursion Available (0 for mDNS)
RCODE: Response code
Questions: Number of questions
Answers: Number of answer RRs
Authority: Number of authority RRs
Additional: Number of additional RRs
mDNS Query Example
Query Message
Multicast to 224.0.0.251:5353
Question:
Name: printer.local
Type: A (IPv4 address)
Class: IN (Internet)
QU bit: 0 (multicast query)
Header:
ID: 0
Flags: 0x0000 (standard query)
Questions: 1
Answers: 0
Response Message
Multicast from 192.168.1.100:5353
Answer:
Name: printer.local
Type: A
Class: IN | Cache-Flush bit
TTL: 120 seconds
Data: 192.168.1.100
Header:
ID: 0
Flags: 0x8400 (authoritative answer)
Questions: 0
Answers: 1
mDNS Record Types
Common Record Types
| Type | Purpose | Example |
|---|---|---|
| A | IPv4 address | device.local → 192.168.1.10 |
| AAAA | IPv6 address | device.local → fe80::1 |
| PTR | Pointer (service discovery) | _http._tcp.local → webserver |
| SRV | Service location | webserver._http._tcp.local → device.local:80 |
| TXT | Text information | Service metadata |
Service Discovery (DNS-SD)
PTR Record: Browse services
_http._tcp.local → webserver._http._tcp.local
SRV Record: Service location
webserver._http._tcp.local
Target: myserver.local
Port: 8080
Priority: 0
Weight: 0
TXT Record: Service metadata
webserver._http._tcp.local
"path=/admin"
"version=1.0"
A Record: IP address
myserver.local → 192.168.1.50
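The PTR → SRV/TXT → A chain above can be modeled as successive lookups in a record store (everything here — `resolve_service` and the dictionary-keyed store — is a hypothetical sketch):

```python
def resolve_service(instance, records):
    """Follow the SRV/TXT/A chain for a DNS-SD service instance."""
    srv = records[("SRV", instance)]           # target host and port
    ip = records[("A", srv["target"])]         # resolve target to an address
    return {"host": srv["target"], "ip": ip, "port": srv["port"],
            "meta": records.get(("TXT", instance), {})}

records = {
    ("SRV", "webserver._http._tcp.local"): {"target": "myserver.local", "port": 8080},
    ("TXT", "webserver._http._tcp.local"): {"path": "/", "version": "2.0"},
    ("A", "myserver.local"): "192.168.1.50",
}
print(resolve_service("webserver._http._tcp.local", records))
```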
mDNS Features
1. Multicast Queries
Traditional DNS (unicast):
Client → DNS Server: "What's example.com?"
DNS Server → Client: "93.184.216.34"
mDNS (multicast):
Client → All devices: "Who has printer.local?"
Printer → All devices: "I'm 192.168.1.100"
Benefits:
- No dedicated server
- All devices hear query
- Multiple responses possible
2. Known-Answer Suppression
Query includes known answers to avoid redundant responses
Client has cached: printer.local → 192.168.1.100
Query:
Question: printer.local?
Known Answer: 192.168.1.100 (TTL > 50% remaining)
Printer sees cached answer is still valid
→ Doesn't respond (saves bandwidth)
3. Cache-Flush Bit
Purpose: Invalidate old cache entries
Response with cache-flush:
printer.local → 192.168.1.100
Class: IN | 0x8000 (cache-flush bit set)
Receivers:
- Flush old records for printer.local
- Cache new record
- Prevents stale data
4. Continuous Verification
Querier sends query even if cached
- Verify host still exists
- Detect IP changes
- Maintain fresh cache
If no response → remove from cache
5. Graceful Shutdown
Device going offline:
Send goodbye message:
printer.local → 192.168.1.100
TTL: 0 (indicates removal)
Other devices:
- Remove from cache immediately
- Don't wait for timeout
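On the receiving side, handling a goodbye is just a TTL-0 check on cache update (illustrative sketch):

```python
def apply_mdns_record(cache, name, ip, ttl):
    """Update a local mDNS cache; TTL 0 is a goodbye and evicts immediately."""
    if ttl == 0:
        cache.pop(name, None)      # remove without waiting for expiry
    else:
        cache[name] = (ip, ttl)

cache = {}
apply_mdns_record(cache, "printer.local", "192.168.1.100", 120)
apply_mdns_record(cache, "printer.local", "192.168.1.100", 0)  # goodbye
print(cache)  # {}
```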
Service Discovery with DNS-SD
Browsing Services
Query:
_services._dns-sd._udp.local PTR?
Response (all available service types):
_http._tcp.local
_printer._tcp.local
_ssh._tcp.local
_sftp-ssh._tcp.local
Finding Specific Service
Query:
_http._tcp.local PTR?
Response (all HTTP services):
webserver._http._tcp.local
api._http._tcp.local
admin._http._tcp.local
Getting Service Details
Query:
webserver._http._tcp.local SRV?
webserver._http._tcp.local TXT?
Response (SRV):
Target: myserver.local
Port: 8080
Priority: 0
Weight: 0
Response (TXT):
path=/
version=2.0
https=true
Then resolve:
myserver.local A? → 192.168.1.50
mDNS Implementation
Python Example (Query)
import socket
import struct

MDNS_ADDR = '224.0.0.251'
MDNS_PORT = 5353

def query_mdns(hostname):
    # Create UDP socket
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 255)
    sock.settimeout(2)

    # Build DNS query
    # Header: ID=0, Flags=0, Questions=1
    query = struct.pack('!HHHHHH', 0, 0, 1, 0, 0, 0)

    # Question: hostname, type A, class IN
    for part in hostname.split('.'):
        query += bytes([len(part)]) + part.encode()
    query += b'\x00'  # End of name
    query += struct.pack('!HH', 1, 1)  # Type A, Class IN

    # Send query
    sock.sendto(query, (MDNS_ADDR, MDNS_PORT))

    # Receive responses until the timeout fires
    responses = []
    try:
        while True:
            data, addr = sock.recvfrom(1024)
            responses.append((data, addr))
    except socket.timeout:
        pass
    sock.close()
    return responses

# Usage
responses = query_mdns('printer.local')
for data, addr in responses:
    print(f"Response from {addr}")
Python Example (Responder using zeroconf)
from zeroconf import ServiceInfo, Zeroconf
import socket

# Get local IP
hostname = socket.gethostname()
local_ip = socket.gethostbyname(hostname)

# Create service info
info = ServiceInfo(
    "_http._tcp.local.",
    "My Web Server._http._tcp.local.",
    addresses=[socket.inet_aton(local_ip)],
    port=8080,
    properties={
        'path': '/',
        'version': '1.0'
    },
    server=f"{hostname}.local."
)

# Register service
zeroconf = Zeroconf()
zeroconf.register_service(info)
print(f"Service registered: {hostname}.local:8080")

try:
    input("Press Enter to unregister...\n")
finally:
    zeroconf.unregister_service(info)
    zeroconf.close()
Avahi (Linux)
# Install Avahi
sudo apt-get install avahi-daemon avahi-utils
# Check hostname
avahi-resolve -n hostname.local
# Browse services
avahi-browse -a
# Browse specific service
avahi-browse _http._tcp
# Publish service
avahi-publish -s "My Service" _http._tcp 8080 path=/
Bonjour (macOS)
# Resolve hostname
dns-sd -G v4 hostname.local
# Browse services
dns-sd -B _http._tcp
# Resolve service
dns-sd -L "My Service" _http._tcp
# Register service
dns-sd -R "My Service" _http._tcp . 8080 path=/
Windows
# Windows 10+ includes mDNS support
# Resolve via PowerShell
Resolve-DnsName hostname.local
# Or use Bonjour SDK
# Download from Apple Developer
mDNS Service Naming
Format
<Instance>._<Service>._<Transport>.local
Examples:
My Printer._printer._tcp.local
Living Room._airplay._tcp.local
Office Server._smb._tcp.local
Kitchen Speaker._raop._tcp.local
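Composing the `<Instance>._<Service>._<Transport>.local` form is a one-liner (illustrative helper):

```python
def service_instance_name(instance, service, transport="tcp"):
    """Compose an mDNS service instance name: <Instance>._<Service>._<Transport>.local"""
    return f"{instance}._{service}._{transport}.local"

print(service_instance_name("My Printer", "printer"))  # My Printer._printer._tcp.local
```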
Common Service Types
_http._tcp Web server
_https._tcp Secure web server
_ssh._tcp SSH server
_sftp-ssh._tcp SFTP over SSH
_ftp._tcp FTP server
_smb._tcp Samba/Windows file sharing
_afpovertcp._tcp Apple File Protocol
_printer._tcp Printer
_ipp._tcp Internet Printing Protocol
_airplay._tcp AirPlay
_raop._tcp Remote Audio Output Protocol
_spotify-connect._tcp Spotify Connect
mDNS Traffic Analysis
Capturing mDNS
# tcpdump
sudo tcpdump -i any -n port 5353
# Wireshark
# Filter: udp.port == 5353
# Follow: Right-click → Follow → UDP Stream
Example Capture
Query:
192.168.1.10 → 224.0.0.251
DNS Query: printer.local A?
Response:
192.168.1.100 → 224.0.0.251
DNS Answer: printer.local → 192.168.1.100 (TTL 120)
mDNS Security Considerations
Vulnerabilities
- No Authentication
Anyone can claim to be "printer.local"
No verification of identity
Potential for spoofing
- Local Network Only
mDNS doesn't cross routers
Limited to link-local multicast
Good for security (confined to LAN)
- Information Disclosure
Services broadcast their presence
Attackers can enumerate:
- Device names
- Service types
- IP addresses
- Software versions
- Name Conflicts
Two devices with same hostname
Both respond to queries
Can cause confusion
Mitigation
1. Firewall rules
- Block port 5353 on external interfaces
- Allow only on trusted LANs
2. VLANs
- Separate guest network
- Prevent mDNS between VLANs
3. Unique hostnames
- Avoid generic names
- Include random identifier
4. Service filtering
- Only advertise necessary services
- Remove unused service announcements
mDNS Performance
Bandwidth Usage
Typical traffic:
- Query: ~50 bytes
- Response: ~100 bytes
- Continuous verification: ~1-2 queries/minute
Low bandwidth impact
Efficient for local networks
Cache Timing
TTL values:
- Typical: 120 seconds (2 minutes)
- High priority: 10 seconds
- Low priority: 4500 seconds (75 minutes)
Refresh at 80% of TTL
Query again at 90% of TTL
Remove at 100% of TTL
Troubleshooting mDNS
Device not responding
# 1. Check mDNS daemon
sudo systemctl status avahi-daemon # Linux
sudo launchctl list | grep mDNS # macOS
# 2. Test multicast
ping -c 3 224.0.0.251
# 3. Check firewall
sudo iptables -L | grep 5353
sudo ufw status
# 4. Capture traffic
sudo tcpdump -i any port 5353
# 5. Resolve manually
avahi-resolve -n device.local
dns-sd -G v4 device.local
Name conflicts
Error: "hostname.local already in use"
Solutions:
1. Rename device
- hostname.local → hostname-2.local
- Automatic on many systems
2. Check for duplicates
- Ensure unique hostnames
- Search network for conflicts
Slow resolution
Causes:
- Network congestion
- Many mDNS devices
- Packet loss
Solutions:
- Reduce query frequency
- Use unicast if possible
- Cache aggressively
mDNS vs DNS
| Feature | Traditional DNS | mDNS |
|---|---|---|
| Server | Centralized server | Distributed (all devices) |
| Configuration | Manual setup | Zero configuration |
| Scope | Internet-wide | Local network only |
| Domain | Any TLD | .local only |
| Protocol | Unicast | Multicast |
| Port | 53 | 5353 |
| Security | DNSSEC available | No authentication |
ELI10
mDNS is like asking a question to everyone in a classroom:
Traditional DNS:
- Raise your hand and ask the teacher
- Teacher has a list of everyone's desks
- Teacher tells you where Alice sits
mDNS (Multicast DNS):
- Stand up and ask: "Where's Alice?"
- Alice hears you and responds: "I'm here at desk 5!"
- Everyone hears both question and answer
- Next time someone asks, they already know
Benefits:
- No need for a teacher (DNS server)
- Works immediately
- Everyone learns everyone else's location
Limitations:
- Only works in one classroom (local network)
- Can't ask about people in other classrooms
- Everyone hears everything (less private)
Real Examples:
- "Where's the printer?" → "printer.local is at 192.168.1.100"
- "Where's my MacBook?" → "macbook.local is at 192.168.1.50"
- "Any web servers?" → "myserver.local has HTTP on port 8080"
It's perfect for homes and small offices where you just want things to work!
Further Resources
Firewalls
Overview
A firewall is a network security system that monitors and controls incoming and outgoing network traffic based on predetermined security rules. It acts as a barrier between trusted internal networks and untrusted external networks (like the internet).
Firewall Types
1. Packet Filtering Firewall
How it works:
- Inspects individual packets
- Makes decisions based on header information
- Stateless (doesn't track connections)
Checks:
- Source IP address
- Destination IP address
- Source port
- Destination port
- Protocol (TCP, UDP, ICMP)
Example Rule:
ALLOW TCP from 192.168.1.0/24 to any port 80
DENY TCP from any to any port 23
Decision Process:
Incoming packet:
Src: 192.168.1.10:54321
Dst: 10.0.0.5:80
Protocol: TCP
Check rules top-to-bottom:
Rule 1: Allow 192.168.1.0/24 to port 80 → MATCH
Action: ALLOW
Packet forwarded
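The top-to-bottom decision process can be sketched as a first-match loop (a toy stateless filter; the rule tuples and `evaluate` function are illustrative):

```python
import ipaddress

def evaluate(rules, packet):
    """First-match-wins stateless packet filter (sketch)."""
    for action, src_net, dport in rules:
        if (ipaddress.ip_address(packet["src"]) in ipaddress.ip_network(src_net)
                and packet["dport"] == dport):
            return action
    return "DENY"  # implicit default policy

rules = [("ALLOW", "192.168.1.0/24", 80)]
print(evaluate(rules, {"src": "192.168.1.10", "dport": 80}))  # ALLOW
```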
Pros:
- Fast (minimal inspection)
- Low resource usage
- Simple configuration
Cons:
- No state tracking
- Can't detect complex attacks
- Vulnerable to IP spoofing
2. Stateful Inspection Firewall
How it works:
- Tracks connection state
- Maintains state table
- Understands context of traffic
State Table Example:
Src IP Src Port Dst IP Dst Port State Protocol
192.168.1.10 54321 93.184.216.34 80 ESTABLISHED TCP
192.168.1.11 54322 8.8.8.8 53 NEW UDP
192.168.1.10 54323 10.0.0.5 22 SYN_SENT TCP
TCP Connection Tracking:
Client → Server: SYN
State: NEW
Server → Client: SYN-ACK
State: ESTABLISHED
Client → Server: ACK
State: ESTABLISHED
... data transfer ...
Client → Server: FIN
State: CLOSING
Server → Client: FIN-ACK
State: CLOSED
Example Rule:
# Outbound rule
ALLOW TCP from 192.168.1.0/24 to any port 80 STATE NEW,ESTABLISHED
# Return traffic automatically allowed
# (tracked in state table)
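The state transitions above can be sketched as updates to a dictionary keyed on the client-side 4-tuple (a toy model, nothing like real conntrack internals):

```python
def track(state_table, src, sport, dst, dport, flags):
    """Toy TCP state tracking keyed on the client-side 4-tuple (sketch)."""
    key = (src, sport, dst, dport)
    if flags == "SYN":
        state_table[key] = "NEW"
    elif flags == "SYN-ACK":
        # Reply travels in the opposite direction: reverse the tuple
        rkey = (dst, dport, src, sport)
        if state_table.get(rkey) == "NEW":
            state_table[rkey] = "ESTABLISHED"
    return state_table

table = {}
track(table, "192.168.1.10", 54321, "10.0.0.5", 80, "SYN")
track(table, "10.0.0.5", 80, "192.168.1.10", 54321, "SYN-ACK")
print(table)
```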
Pros:
- Understands connection context
- Better security than packet filtering
- Prevents spoofing attacks
- Allows related traffic
Cons:
- More resource intensive
- State table can be exhausted
- Performance impact at scale
3. Application Layer Firewall (Proxy Firewall)
How it works:
- Operates at Layer 7 (Application)
- Acts as intermediary (proxy)
- Deep packet inspection
- Understands application protocols
Proxy Flow:
Client → Proxy → Server
Client connects to proxy
Proxy inspects full request
Proxy makes decision
Proxy connects to server (if allowed)
Proxy relays response to client
Inspection Capabilities:
HTTP/HTTPS:
- URL filtering
- Content scanning
- Malware detection
- Data loss prevention
FTP:
- Command filtering
- File type restrictions
SMTP:
- Spam filtering
- Attachment scanning
Example:
HTTP Request:
GET /admin.php HTTP/1.1
Host: example.com
Proxy checks:
1. Is /admin.php allowed? → NO
2. Block request
3. Return 403 Forbidden
Pros:
- Deep inspection
- Understands application protocols
- Can filter content
- Hides internal network
- Logging and auditing
Cons:
- Significant performance impact
- Complex configuration
- May break some applications
- Single point of failure
4. Next-Generation Firewall (NGFW)
Combines:
- Traditional firewall functions
- Intrusion Prevention System (IPS)
- Application awareness
- SSL/TLS inspection
- Advanced threat protection
Features:
1. Deep Packet Inspection (DPI)
- Full packet content analysis
2. Application Control
- Block Facebook but allow LinkedIn
- Control by application, not just port
3. User Identity
- Rules based on user/group
- Active Directory integration
4. Threat Intelligence
- Malware detection
- Botnet protection
- Zero-day protection
5. SSL Inspection
- Decrypt HTTPS traffic
- Inspect encrypted content
- Re-encrypt and forward
Example NGFW Rule:
DENY application "BitTorrent" for group "Employees"
ALLOW application "Salesforce" for group "Sales"
BLOCK malware signature "Trojan.Generic.123"
Firewall Architectures
1. Packet Filtering Router
Internet → [Router with ACL] → Internal Network
Simple, single layer of protection
2. Dual-Homed Host
Internet → [Firewall with 2 NICs] → Internal Network
(All traffic through firewall)
Complete traffic control
3. Screened Host
Internet → [Router] → [Firewall Host] → Internal Network
Router filters basic traffic
Firewall provides additional protection
4. Screened Subnet (DMZ)
Internet → [External FW] → [DMZ (Web, Mail)] → [Internal FW] → Internal Network
Public services in DMZ
Internal network isolated
DMZ Example:
External Firewall Rules:
- Allow HTTP/HTTPS to web server (DMZ)
- Allow SMTP to mail server (DMZ)
- Deny all to internal network
Internal Firewall Rules:
- Allow web server to database (specific port)
- Allow mail server to internal mail (specific port)
- Deny all other DMZ traffic to internal
Firewall Rules
Rule Components
1. Source: Where traffic originates
2. Destination: Where traffic is going
3. Service/Port: What service (HTTP, SSH, etc.)
4. Action: Allow, Deny, Reject
5. Direction: Inbound, Outbound
6. State: NEW, ESTABLISHED, RELATED
Rule Example (iptables)
# Allow SSH from specific network
iptables -A INPUT -s 192.168.1.0/24 -p tcp --dport 22 -j ACCEPT
# Allow established connections
iptables -A INPUT -m state --state ESTABLISHED,RELATED -j ACCEPT
# Allow HTTP and HTTPS
iptables -A INPUT -p tcp --dport 80 -j ACCEPT
iptables -A INPUT -p tcp --dport 443 -j ACCEPT
# Drop everything else
iptables -A INPUT -j DROP
Rule Ordering
Important: Rules processed top-to-bottom, first match wins
# WRONG ORDER:
1. DENY all
2. ALLOW HTTP port 80 → Never reached!
# CORRECT ORDER:
1. ALLOW HTTP port 80
2. DENY all
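The first-match-wins behaviour above can be sketched with a toy evaluator. The rule format here is hypothetical, purely for illustration, not any real firewall syntax:

```python
def evaluate(rules, packet):
    """Return the action of the first rule that matches the packet."""
    for action, match in rules:
        if match(packet):
            return action
    return "DROP"  # default policy if nothing matches

# Two rules: a specific ALLOW and a catch-all DENY
allow_http = ("ALLOW", lambda p: p["dport"] == 80)
deny_all = ("DENY", lambda p: True)

packet = {"dport": 80}

# Correct order: specific ALLOW before catch-all DENY
print(evaluate([allow_http, deny_all], packet))  # ALLOW

# Wrong order: DENY-all matches first, ALLOW is never reached
print(evaluate([deny_all, allow_http], packet))  # DENY
```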
Default Policy
# Default DENY (whitelist approach - more secure)
iptables -P INPUT DROP
iptables -P FORWARD DROP
iptables -P OUTPUT ACCEPT
# Then explicitly allow needed services
# Default ALLOW (blacklist approach - less secure)
iptables -P INPUT ACCEPT
iptables -P FORWARD ACCEPT
iptables -P OUTPUT ACCEPT
# Then explicitly block dangerous services
Common Firewall Configurations
1. Linux iptables
View rules:
iptables -L -v -n
Basic web server protection:
# Flush existing rules
iptables -F
# Default policies
iptables -P INPUT DROP
iptables -P FORWARD DROP
iptables -P OUTPUT ACCEPT
# Allow loopback
iptables -A INPUT -i lo -j ACCEPT
# Allow established connections
iptables -A INPUT -m state --state ESTABLISHED,RELATED -j ACCEPT
# Allow SSH (from specific network)
iptables -A INPUT -s 192.168.1.0/24 -p tcp --dport 22 -j ACCEPT
# Allow HTTP/HTTPS
iptables -A INPUT -p tcp --dport 80 -j ACCEPT
iptables -A INPUT -p tcp --dport 443 -j ACCEPT
# Allow ping
iptables -A INPUT -p icmp --icmp-type echo-request -j ACCEPT
# Log dropped packets
iptables -A INPUT -j LOG --log-prefix "DROPPED: "
iptables -A INPUT -j DROP
# Save rules
iptables-save > /etc/iptables/rules.v4
2. Linux ufw (Uncomplicated Firewall)
# Enable firewall
ufw enable
# Default policies
ufw default deny incoming
ufw default allow outgoing
# Allow SSH
ufw allow 22/tcp
# Allow HTTP/HTTPS
ufw allow 80/tcp
ufw allow 443/tcp
# Allow from specific IP
ufw allow from 192.168.1.100
# Allow specific port from specific IP
ufw allow from 192.168.1.100 to any port 3306
# View rules
ufw status numbered
# Delete rule
ufw delete 5
3. Linux firewalld
# View zones
firewall-cmd --get-active-zones
# Add service to zone
firewall-cmd --zone=public --add-service=http
firewall-cmd --zone=public --add-service=https
# Add port
firewall-cmd --zone=public --add-port=8080/tcp
# Add rich rule
firewall-cmd --zone=public --add-rich-rule='rule family="ipv4" source address="192.168.1.0/24" service name="ssh" accept'
# Make permanent
firewall-cmd --runtime-to-permanent
# Reload
firewall-cmd --reload
4. Windows Firewall
# View rules
Get-NetFirewallRule
# Enable firewall
Set-NetFirewallProfile -Profile Domain,Public,Private -Enabled True
# Allow inbound port
New-NetFirewallRule -DisplayName "Allow HTTP" -Direction Inbound -LocalPort 80 -Protocol TCP -Action Allow
# Allow program
New-NetFirewallRule -DisplayName "My App" -Direction Inbound -Program "C:\App\myapp.exe" -Action Allow
# Block IP address
New-NetFirewallRule -DisplayName "Block IP" -Direction Inbound -RemoteAddress 10.0.0.5 -Action Block
Port Knocking
Concept: Hidden service that opens after specific sequence
Example:
# Ports closed by default
# Client knocks sequence: 1234, 5678, 9012
nc -z server.com 1234
nc -z server.com 5678
nc -z server.com 9012
# Server detects sequence, opens SSH port 22 for client IP
# Client can now SSH
# After timeout, port closes again
Configuration (knockd):
[openSSH]
sequence = 1234,5678,9012
seq_timeout = 10
command = /sbin/iptables -I INPUT -s %IP% -p tcp --dport 22 -j ACCEPT
tcpflags = syn
[closeSSH]
sequence = 9012,5678,1234
seq_timeout = 10
command = /sbin/iptables -D INPUT -s %IP% -p tcp --dport 22 -j ACCEPT
tcpflags = syn
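A knock client equivalent to the `nc -z` sequence above can be sketched in Python. The hostname and ports come from the example; the timeout value is an assumption:

```python
import socket

# Knock sequence matching the example knockd config above
SEQUENCE = [1234, 5678, 9012]
SERVER = "server.com"  # placeholder host from the example

def knock(host, ports, timeout=0.5):
    """Send one TCP SYN to each port in order (the equivalent of `nc -z`)."""
    for port in ports:
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.settimeout(timeout)
        try:
            s.connect((host, port))  # ports are closed, so this fails
        except (socket.timeout, ConnectionRefusedError, OSError):
            pass  # failure is expected; the SYN is what knockd watches for
        finally:
            s.close()

knock(SERVER, SEQUENCE)
# knockd on the server should now have opened port 22 for our IP
```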
NAT (Network Address Translation)
Source NAT (SNAT) / Masquerading
Purpose: Hide internal IPs behind single public IP
# iptables NAT
iptables -t nat -A POSTROUTING -o eth0 -j MASQUERADE
# Or specific IP
iptables -t nat -A POSTROUTING -o eth0 -j SNAT --to-source 203.0.113.5
Traffic Flow:
Internal: 192.168.1.10:5000 → Internet
External: 203.0.113.5:6000 → Internet
Firewall tracks connection:
192.168.1.10:5000 ↔ 203.0.113.5:6000
Return traffic:
Internet → 203.0.113.5:6000
Firewall translates back to: 192.168.1.10:5000
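The connection tracking described above can be modelled with a toy SNAT table. This is a sketch assuming a single public IP and sequential port allocation, not how a real conntrack implementation works:

```python
PUBLIC_IP = "203.0.113.5"

class SnatTable:
    """Toy SNAT: maps internal (ip, port) pairs to public ports and back."""
    def __init__(self, first_port=6000):
        self.next_port = first_port
        self.out = {}   # (internal_ip, internal_port) -> public_port
        self.back = {}  # public_port -> (internal_ip, internal_port)

    def translate_out(self, ip, port):
        key = (ip, port)
        if key not in self.out:
            self.out[key] = self.next_port
            self.back[self.next_port] = key
            self.next_port += 1
        return (PUBLIC_IP, self.out[key])

    def translate_in(self, public_port):
        # Return traffic: look up which internal host owns this port
        return self.back.get(public_port)

nat = SnatTable()
print(nat.translate_out("192.168.1.10", 5000))  # ('203.0.113.5', 6000)
print(nat.translate_in(6000))                   # ('192.168.1.10', 5000)
```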
Destination NAT (DNAT) / Port Forwarding
Purpose: Expose internal service on public IP
# Forward public port 80 to internal web server
iptables -t nat -A PREROUTING -i eth0 -p tcp --dport 80 -j DNAT --to-destination 192.168.1.20:80
# Forward public port 2222 to internal SSH
iptables -t nat -A PREROUTING -i eth0 -p tcp --dport 2222 -j DNAT --to-destination 192.168.1.10:22
Traffic Flow:
Internet → 203.0.113.5:80
Firewall translates to: 192.168.1.20:80
Web server processes request
Response: 192.168.1.20:80 → Internet
Firewall translates from: 203.0.113.5:80 → Internet
Firewall Evasion Techniques (for awareness)
1. Fragmentation
Split malicious payload across fragments
Some firewalls don't reassemble
2. IP Spoofing
Fake source IP address
Bypass source-based rules
3. Tunneling
Encapsulate forbidden traffic in allowed protocol
Example: SSH tunnel, DNS tunnel, ICMP tunnel
4. Encryption
Encrypt malicious traffic
Firewall can't inspect without SSL inspection
Defense:
- Fragment reassembly
- Anti-spoofing rules
- Protocol validation
- SSL/TLS inspection
- Deep packet inspection
Firewall Logging
What to Log
1. Blocked connections
2. Allowed critical connections
3. Rule changes
4. Authentication events
5. Anomalies (port scans, floods)
iptables Logging
# Log dropped packets
iptables -A INPUT -j LOG --log-prefix "DROPPED INPUT: " --log-level 4
iptables -A INPUT -j DROP
# Log accepted SSH
iptables -A INPUT -p tcp --dport 22 -j LOG --log-prefix "SSH ACCEPT: "
iptables -A INPUT -p tcp --dport 22 -j ACCEPT
Log Analysis
# View firewall logs (typical locations)
tail -f /var/log/syslog
tail -f /var/log/messages
tail -f /var/log/kern.log
# Search for dropped packets
grep "DROPPED" /var/log/syslog
# Count dropped packets by source IP (extract the SRC= field from the LOG output)
grep "DROPPED" /var/log/syslog | grep -o 'SRC=[^ ]*' | sort | uniq -c | sort -n
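The same per-source counting can be sketched in Python. The sample lines below are illustrative of the iptables LOG format (`SRC=` fields), not real log data:

```python
import re
from collections import Counter

# Illustrative lines mimicking the iptables LOG prefix used above
sample = [
    'kernel: DROPPED: IN=eth0 SRC=10.0.0.5 DST=192.168.1.1 PROTO=TCP DPT=23',
    'kernel: DROPPED: IN=eth0 SRC=10.0.0.5 DST=192.168.1.1 PROTO=TCP DPT=445',
    'kernel: DROPPED: IN=eth0 SRC=10.0.0.9 DST=192.168.1.1 PROTO=UDP DPT=53',
]

counts = Counter()
for line in sample:
    if "DROPPED" not in line:
        continue
    m = re.search(r'SRC=(\S+)', line)  # pull out the source IP
    if m:
        counts[m.group(1)] += 1

print(counts.most_common())  # [('10.0.0.5', 2), ('10.0.0.9', 1)]
```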
Firewall Best Practices
1. Default Deny
Block everything by default
Explicitly allow only needed services
2. Principle of Least Privilege
Open only necessary ports
Restrict to specific sources when possible
Time-based rules when appropriate
3. Defense in Depth
Multiple layers:
- Perimeter firewall
- Host-based firewalls
- Network segmentation
- Application firewalls
4. Regular Updates
- Keep firewall software updated
- Review rules periodically
- Remove unused rules
- Update threat signatures (NGFW)
5. Monitoring and Alerts
- Enable logging
- Set up alerts for anomalies
- Regular log reviews
- Incident response plan
6. Testing
- Test rules before production
- Verify deny rules work
- Check for unintended access
- Regular security audits
Troubleshooting Firewall Issues
Can't connect to service
# 1. Check if service is running
systemctl status nginx
# 2. Check if service is listening
netstat -tuln | grep :80
ss -tuln | grep :80
# 3. Check firewall rules
iptables -L -n -v
ufw status
firewall-cmd --list-all
# 4. Check logs
tail -f /var/log/syslog | grep UFW
journalctl -f -u firewalld
# 5. Test from different source
curl http://server-ip
telnet server-ip 80
Connection works locally but not remotely
# Likely firewall blocking external access
# Check INPUT chain
iptables -L INPUT -n -v
# Temporarily allow (testing only!)
iptables -I INPUT -p tcp --dport 80 -j ACCEPT
# If works, add permanent rule
Rule not working
# Check rule order
iptables -L -n -v --line-numbers
# Rules processed top-to-bottom
# Earlier DENY rule might catch traffic before ALLOW
# Reorder rules
iptables -I INPUT 1 -p tcp --dport 80 -j ACCEPT
ELI10
A firewall is like a security guard at a building entrance:
Security Guard (Firewall):
- Checks everyone coming in and out
- Has a list of rules (who's allowed, who's not)
- Blocks suspicious people
- Keeps a log of who enters
Types of Security:
1. Basic Guard (Packet Filter):
- Checks ID cards only
- Fast but simple
2. Smart Guard (Stateful):
- Remembers who entered
- Allows them to leave
- Tracks conversations
3. Super Guard (Application Layer):
- Opens bags
- Checks what you're carrying
- Very thorough but slower
4. AI Guard (NGFW):
- Facial recognition
- Detects threats automatically
- Learns from experience
Rules Example:
- "Allow employees" (like allowing HTTP port 80)
- "Block suspicious visitors" (like blocking unknown IPs)
- "Only executives can enter executive floor" (like restricting SSH to specific IPs)
DMZ is like a reception area:
- Visitors wait here
- Can't go into main office
- Receptionists (DMZ servers) handle requests
STUN (Session Traversal Utilities for NAT)
Overview
STUN is a standardized network protocol that allows clients behind NAT (Network Address Translation) to discover their public IP address and the type of NAT they are behind. This information is crucial for establishing peer-to-peer connections in applications like VoIP, video conferencing, and WebRTC.
The NAT Problem
Why STUN is Needed
Private Network              NAT Router               Internet
                            (Public IP)
+-------------------+      +---------------+      +-----------------+
| PC1: 192.168.1.10 |      | Router        |      | Other peer      |
| PC2: 192.168.1.11 | ---> | External IP:  | ---> | wants to        |
| PC3: 192.168.1.12 |      | 203.0.113.5   |      | connect to you  |
+-------------------+      +---------------+      +-----------------+
Problem: How does external peer know your public IP and port?
Solution: STUN server tells you!
Without STUN
Peer A (behind NAT) wants to connect to Peer B
Peer A knows only: 192.168.1.10 (private IP)
Peer B needs: 203.0.113.5:54321 (public IP:port)
Peer A can't tell Peer B how to reach it
With STUN
Peer A queries STUN server
STUN server responds: "I see you as 203.0.113.5:54321"
Peer A tells Peer B: "Connect to 203.0.113.5:54321"
Peer B connects successfully
STUN Architecture
Client STUN Server Peer
(Behind NAT) (Public IP)
| | |
| STUN Binding Request | |
|--------------------------->| |
| | |
| STUN Binding Response | |
|<---------------------------| |
| (Your public IP:Port) | |
| | |
| Send public IP:Port | |
|-------------------------------------------------->|
| | |
| Direct connection established |
|<------------------------------------------------->|
STUN Message Format
Message Structure
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|0 0| STUN Message Type | Message Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Magic Cookie |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
| Transaction ID (96 bits) |
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Attributes |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Header Fields
1. Message Type (16 bits):
- Class: Request (0b00), Indication (0b01), Success Response (0b10), Error Response (0b11)
- Method: Binding (0x001)
2. Message Length (16 bits):
- Length of attributes (excluding 20-byte header)
3. Magic Cookie (32 bits):
- Fixed value: 0x2112A442
- Helps distinguish STUN from other protocols
4. Transaction ID (96 bits):
- Unique identifier for matching requests/responses
Message Types
| Type | Value | Description |
|---|---|---|
| Binding Request | 0x0001 | Request public IP/port |
| Binding Response | 0x0101 | Success response with address |
| Binding Error | 0x0111 | Error response |
STUN Attributes
Common Attributes
| Attribute | Type | Description |
|---|---|---|
| MAPPED-ADDRESS | 0x0001 | Reflexive transport address (legacy) |
| XOR-MAPPED-ADDRESS | 0x0020 | XORed reflexive address (preferred) |
| USERNAME | 0x0006 | Username for authentication |
| MESSAGE-INTEGRITY | 0x0008 | HMAC-SHA1 hash |
| ERROR-CODE | 0x0009 | Error code and reason |
| UNKNOWN-ATTRIBUTES | 0x000A | Unknown required attributes |
| REALM | 0x0014 | Realm for authentication |
| NONCE | 0x0015 | Nonce for digest authentication |
| SOFTWARE | 0x8022 | Software version |
| FINGERPRINT | 0x8028 | CRC-32 of message |
XOR-MAPPED-ADDRESS Format
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|0 0 0 0 0 0 0 0| Family | X-Port |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| X-Address (Variable) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Family: 0x01 (IPv4), 0x02 (IPv6)
X-Port: Port XORed with most significant 16 bits of magic cookie
X-Address: IP address XORed with magic cookie (and transaction ID for IPv6)
Why XOR?
- Prevents middle boxes from modifying the address
- Some NAT devices inspect and modify IP addresses in packets
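The XOR trick is easy to verify in Python. A minimal sketch of the IPv4 case, using the magic cookie value from the header format above:

```python
import socket
import struct

MAGIC_COOKIE = 0x2112A442

def xor_port(port):
    # Port is XORed with the top 16 bits of the magic cookie (0x2112)
    return port ^ (MAGIC_COOKIE >> 16)

def xor_ipv4(ip_str):
    # IPv4 address is XORed with the full 32-bit magic cookie
    (ip_int,) = struct.unpack('!I', socket.inet_aton(ip_str))
    return socket.inet_ntoa(struct.pack('!I', ip_int ^ MAGIC_COOKIE))

print(hex(xor_port(54321)))               # what actually goes on the wire
print(xor_port(xor_port(54321)))          # 54321 — XOR is its own inverse
print(xor_ipv4(xor_ipv4("203.0.113.5")))  # 203.0.113.5 — round-trips cleanly
```

Because the address on the wire never appears in plain form, a NAT that rewrites embedded IP addresses can't find (and corrupt) it.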
STUN Transaction Example
Binding Request
Client → STUN Server (UDP port 3478)
Message Type: Binding Request (0x0001)
Message Length: 0
Magic Cookie: 0x2112A442
Transaction ID: 0xB7E7A701BC34D686FA87DFAE
No attributes in basic request
Hexadecimal:
00 01 00 00 21 12 A4 42
B7 E7 A7 01 BC 34 D6 86
FA 87 DF AE
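The hex dump above can be reproduced by packing the header fields directly (same type, length, cookie, and transaction ID as in the example):

```python
import struct

MAGIC_COOKIE = 0x2112A442
transaction_id = bytes.fromhex("B7E7A701BC34D686FA87DFAE")

# Binding Request: type 0x0001, length 0 (no attributes), cookie, transaction ID
header = struct.pack('!HHI', 0x0001, 0, MAGIC_COOKIE) + transaction_id

print(header.hex(' ').upper())
# 00 01 00 00 21 12 A4 42 B7 E7 A7 01 BC 34 D6 86 FA 87 DF AE
```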
Binding Response
STUN Server → Client
Message Type: Binding Response (0x0101)
Message Length: 12 (length of attributes)
Magic Cookie: 0x2112A442
Transaction ID: 0xB7E7A701BC34D686FA87DFAE (same as request)
Attributes:
XOR-MAPPED-ADDRESS:
Family: IPv4 (0x01)
Port: 54321 (XORed)
IP: 203.0.113.5 (XORed)
Information extracted:
Your public IP address: 203.0.113.5
Your public port: 54321
NAT binding created: 192.168.1.10:5000 ↔ 203.0.113.5:54321
NAT Types Discovered by STUN
1. Full Cone NAT
Internal: 192.168.1.10:5000
NAT creates mapping:
192.168.1.10:5000 ↔ 203.0.113.5:6000
Any external host can send to 203.0.113.5:6000
→ Forwarded to 192.168.1.10:5000
Best for P2P (easy to traverse)
2. Restricted Cone NAT
Internal: 192.168.1.10:5000
NAT creates mapping:
192.168.1.10:5000 ↔ 203.0.113.5:6000
External host 1.2.3.4 can send to 203.0.113.5:6000
ONLY IF 192.168.1.10:5000 previously sent to 1.2.3.4
Moderate difficulty to traverse
3. Port Restricted Cone NAT
Internal: 192.168.1.10:5000
NAT creates mapping:
192.168.1.10:5000 ↔ 203.0.113.5:6000
External host 1.2.3.4:7000 can send to 203.0.113.5:6000
ONLY IF 192.168.1.10:5000 previously sent to 1.2.3.4:7000
More difficult to traverse
4. Symmetric NAT
Internal: 192.168.1.10:5000
NAT creates different mappings per destination:
To host A: 192.168.1.10:5000 ↔ 203.0.113.5:6000
To host B: 192.168.1.10:5000 ↔ 203.0.113.5:6001
To host C: 192.168.1.10:5000 ↔ 203.0.113.5:6002
Difficult to traverse (may need TURN relay)
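The key difference between cone and symmetric NATs is what the mapping is keyed on. A toy model (illustrative only, port numbers as in the examples above):

```python
def cone_mapping(table, src, dst, next_port):
    # Cone NAT: mapping keyed on the internal socket only,
    # so every destination sees the same external port
    if src not in table:
        table[src] = next_port[0]
        next_port[0] += 1
    return table[src]

def symmetric_mapping(table, src, dst, next_port):
    # Symmetric NAT: mapping keyed on (src, dst),
    # so each destination gets a fresh external port
    key = (src, dst)
    if key not in table:
        table[key] = next_port[0]
        next_port[0] += 1
    return table[key]

src = ("192.168.1.10", 5000)
cone, sym, np1, np2 = {}, {}, [6000], [6000]

# Cone NAT: the STUN server and the peer see the SAME public port
print(cone_mapping(cone, src, "stun-server", np1))  # 6000
print(cone_mapping(cone, src, "peer", np1))         # 6000

# Symmetric NAT: the port STUN reported is useless for the peer
print(symmetric_mapping(sym, src, "stun-server", np2))  # 6000
print(symmetric_mapping(sym, src, "peer", np2))         # 6001
```

This is exactly why STUN's answer is only valid for cone-style NATs, and symmetric NAT usually forces a TURN relay.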
STUN Usage in ICE
ICE (Interactive Connectivity Establishment) uses STUN:
ICE Candidate Gathering
1. Host Candidate:
Local IP: 192.168.1.10:5000
2. Server Reflexive Candidate (from STUN):
Public IP: 203.0.113.5:6000
3. Relayed Candidate (from TURN):
Relay IP: 198.51.100.1:7000
Try connections in order:
1. Direct (host to host)
2. Through NAT (server reflexive)
3. Through relay (last resort)
WebRTC Connection Flow
Peer A STUN Server Peer B
| | |
| Get my public IP | |
|--------------------------->| |
| | |
| 203.0.113.5:6000 | |
|<---------------------------| |
| | |
| Exchange candidates via signaling server |
|<------------------------------------------------->|
| | |
| Try connection | |
|<------------------------------------------------->|
| Connectivity check (STUN) | |
|<------------------------------------------------->|
| | |
| Connection established | |
|<=================================================>|
STUN Authentication
Short-Term Credentials
Request:
USERNAME: "alice:bob"
MESSAGE-INTEGRITY: HMAC-SHA1(message, password)
Server validates:
1. Check username exists
2. Compute HMAC with stored password
3. Compare with MESSAGE-INTEGRITY
4. Accept or reject
Long-Term Credentials
Request 1 (no credentials):
Binding Request
Response 1:
Error 401 Unauthorized
REALM: "example.com"
NONCE: "random-nonce-12345"
Request 2 (with credentials):
USERNAME: "alice"
REALM: "example.com"
NONCE: "random-nonce-12345"
MESSAGE-INTEGRITY: HMAC-SHA1(message, MD5(username:realm:password))
Response 2:
Binding Success Response
XOR-MAPPED-ADDRESS: ...
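A minimal sketch of the long-term credential key derivation shown above. The password value and message bytes are assumptions for illustration:

```python
import hashlib
import hmac

# Values from the exchange above; the password is an assumed example
username, realm, password = "alice", "example.com", "secret"

# Long-term key = MD5(username:realm:password)
key = hashlib.md5(f"{username}:{realm}:{password}".encode()).digest()

# MESSAGE-INTEGRITY = HMAC-SHA1 over the message using that key
message = b"...STUN message up to MESSAGE-INTEGRITY..."  # placeholder bytes
integrity = hmac.new(key, message, hashlib.sha1).digest()

print(len(key))        # 16 — MD5 digest
print(len(integrity))  # 20 — HMAC-SHA1 digest, the attribute's value length
```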
STUN Client Implementation
Python Example
import os
import socket
import struct

STUN_SERVER = "stun.l.google.com"
STUN_PORT = 19302
MAGIC_COOKIE = 0x2112A442

def create_stun_binding_request():
    # Message type: Binding Request (0x0001)
    msg_type = 0x0001
    # Message length: 0 (no attributes)
    msg_length = 0
    # Transaction ID: 96 random bits
    transaction_id = os.urandom(12)
    # Pack 20-byte header
    header = struct.pack('!HHI', msg_type, msg_length, MAGIC_COOKIE) + transaction_id
    return header, transaction_id

def parse_stun_response(data, transaction_id):
    # Parse header
    msg_type, msg_length, magic_cookie = struct.unpack('!HHI', data[:8])
    recv_transaction_id = data[8:20]
    # Verify transaction ID
    if recv_transaction_id != transaction_id:
        raise Exception("Transaction ID mismatch")
    # Parse attributes (each value is padded to a 4-byte boundary)
    offset = 20
    while offset + 4 <= len(data):
        attr_type, attr_length = struct.unpack('!HH', data[offset:offset+4])
        offset += 4
        if attr_type == 0x0020:  # XOR-MAPPED-ADDRESS
            family = data[offset + 1]
            x_port = struct.unpack('!H', data[offset+2:offset+4])[0]
            x_ip = struct.unpack('!I', data[offset+4:offset+8])[0]
            # Un-XOR with the magic cookie
            port = x_port ^ (MAGIC_COOKIE >> 16)
            ip = x_ip ^ MAGIC_COOKIE
            ip_addr = socket.inet_ntoa(struct.pack('!I', ip))
            return ip_addr, port
        # Skip value plus padding to the next 4-byte boundary
        offset += (attr_length + 3) & ~3
    return None, None

def get_public_ip_port():
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    sock.settimeout(3)
    try:
        # Create and send binding request
        request, transaction_id = create_stun_binding_request()
        sock.sendto(request, (STUN_SERVER, STUN_PORT))
        # Receive and parse response
        data, addr = sock.recvfrom(1024)
        return parse_stun_response(data, transaction_id)
    finally:
        sock.close()

# Usage
public_ip, public_port = get_public_ip_port()
print(f"Public IP: {public_ip}:{public_port}")
JavaScript (WebRTC) Example
// Create RTCPeerConnection with STUN servers
const configuration = {
  iceServers: [
    { urls: 'stun:stun.l.google.com:19302' },
    { urls: 'stun:stun1.l.google.com:19302' }
  ]
};
const pc = new RTCPeerConnection(configuration);

// Listen for ICE candidates
pc.onicecandidate = (event) => {
  if (event.candidate) {
    console.log('ICE Candidate:', event.candidate);
    // Send candidate to remote peer via signaling
  }
};

// Create offer to trigger ICE gathering
pc.createOffer()
  .then(offer => pc.setLocalDescription(offer))
  .then(() => {
    // ICE candidates will be gathered
    // and onicecandidate will be called
  });
Public STUN Servers
Free STUN Servers
Google:
stun.l.google.com:19302
stun1.l.google.com:19302
stun2.l.google.com:19302
stun3.l.google.com:19302
stun4.l.google.com:19302
Twilio:
global.stun.twilio.com:3478
OpenRelay:
stun.relay.metered.ca:80
Testing STUN Server
# Using stunclient (stuntman tools)
stunclient stun.l.google.com
# Output example:
# Binding test: success
# Local address: 192.168.1.10:45678
# Mapped address: 203.0.113.5:45678
STUN Limitations
1. Doesn't Work with Symmetric NAT
STUN tells you: 203.0.113.5:6000
But when connecting to peer, NAT assigns: 203.0.113.5:6001
Peer can't connect to you
→ Need TURN relay
2. Requires UDP
Some networks block UDP
STUN won't work
→ Need TCP fallback or TURN over TCP
3. Firewall Issues
Restrictive firewalls may block P2P connections
Even with correct IP:port from STUN
→ Need TURN relay
4. No Data Relay
STUN only discovers address
Doesn't relay data
If direct connection fails, need TURN
STUN vs TURN vs ICE
STUN:
- Discovers public IP:port
- Lightweight
- No bandwidth cost
- Doesn't always work
TURN:
- Relays traffic
- Always works
- Bandwidth intensive
- Costs money
ICE:
- Uses both STUN and TURN
- Tries STUN first
- Falls back to TURN
- Best of both worlds
STUN Server Setup
Using coturn
# Install
sudo apt-get install coturn
# Configure /etc/turnserver.conf
listening-port=3478
fingerprint
lt-cred-mech
use-auth-secret
static-auth-secret=YOUR_SECRET
realm=example.com
total-quota=100
stale-nonce=600
Run STUN Server
# Start server
sudo turnserver -v
# Test locally
stunclient localhost
ELI10
STUN is like asking a friend "What's my address?" when you can't see it yourself:
The Problem:
- You live in an apartment building (NAT)
- Someone outside wants to send you mail
- They need your full address, not just "Apartment 5"
STUN Solution:
- You call a friend outside (STUN server)
- Friend says: "I see your address as 123 Main St, Apartment 5"
- You tell pen pal: "Send letters to 123 Main St, Apt 5"
- Pen pal can now reach you!
NAT Types:
- Full Cone: Anyone can mail you once you have the address
- Restricted: Only people you mailed first can mail back
- Symmetric: Building assigns different box for each sender (hard!)
When STUN Doesn't Work:
- Symmetric NAT: Address changes for each recipient
- Firewall: Building doesn't accept outside mail
- → Need TURN (a forwarding service)
WebRTC Uses STUN:
- Video calls discover how to reach each other
- Try direct connection first (with STUN)
- Use relay (TURN) if direct doesn't work
Further Resources
- RFC 5389 - STUN Specification
- RFC 8489 - STUN Update
- WebRTC and STUN
- Interactive STUN Test
- coturn Server
TURN (Traversal Using Relays around NAT)
Overview
TURN is a protocol that helps establish connections between peers when direct peer-to-peer communication fails. Unlike STUN which only discovers addresses, TURN acts as a relay server that forwards traffic between peers when NAT or firewall restrictions prevent direct connections.
Why TURN is Needed
When STUN Fails
Scenario 1: Symmetric NAT
Peer A behind Symmetric NAT
Different public port for each destination
STUN can't provide usable address
→ Need TURN relay
Scenario 2: Restrictive Firewall
Corporate firewall blocks incoming P2P
Even with correct address from STUN
→ Need TURN relay
Scenario 3: UDP Blocked
Network blocks UDP traffic
Can't use STUN or direct P2P
→ Need TURN over TCP
TURN vs STUN
| Feature | STUN | TURN |
|---|---|---|
| Purpose | Discover public address | Relay traffic |
| Bandwidth | Minimal (discovery only) | High (relays all data) |
| Success Rate | ~80% | ~100% |
| Cost | Free (public servers) | Expensive (bandwidth) |
| Latency | Low (direct connection) | Higher (via relay) |
| When to Use | First attempt | Fallback |
TURN Architecture
Basic Relay
Peer A TURN Server Peer B
(Behind NAT) (Public IP) (Behind NAT)
192.168.1.10 198.51.100.1 10.0.0.5
| | |
| Allocate Request | |
|--------------------------->| |
| Allocate Success | |
|<---------------------------| |
| (Relayed address assigned) | |
| | |
| Send relayed address | |
| to Peer B via signaling | |
| | |
| Data | Data |
|--------------------------->|------------------------>|
| | (TURN relays) |
| Data | Data |
|<---------------------------|<------------------------|
Allocation
Client requests allocation from TURN server:
1. Client: "I need a relay address"
2. TURN: "Here's 198.51.100.1:50000 for you"
3. Client: "Route traffic between me and Peer X"
4. TURN: "OK, I'll relay your traffic"
Allocation lifetime: 10 minutes (default, can be refreshed)
TURN Message Types
Key Operations
| Operation | Description |
|---|---|
| Allocate | Request relay address |
| Refresh | Extend allocation lifetime |
| Send | Send data through relay |
| Data | Receive data from relay |
| CreatePermission | Allow peer to send data |
| ChannelBind | Optimize data transfer |
Allocate Request/Response
Request:
Client → TURN Server
Method: Allocate
Attributes:
REQUESTED-TRANSPORT: UDP (17)
LIFETIME: 600 seconds
USERNAME: "alice"
MESSAGE-INTEGRITY: HMAC
Response:
TURN Server → Client
Method: Allocate Success
Attributes:
XOR-RELAYED-ADDRESS: 198.51.100.1:50000
LIFETIME: 600 seconds
XOR-MAPPED-ADDRESS: 203.0.113.5:54321 (client's public IP)
MESSAGE-INTEGRITY: HMAC
TURN Workflow
1. Allocation
Client TURN Server
| |
| Allocate Request |
| (credentials, transport) |
|------------------------------->|
| |
| Allocate Success |
| (relayed address) |
|<-------------------------------|
| |
Allocation created:
Client: 203.0.113.5:54321
Relay: 198.51.100.1:50000
Lifetime: 600 seconds
2. Permission
Client TURN Server
| |
| CreatePermission Request |
| (peer IP: 10.0.0.5) |
|------------------------------->|
| |
| CreatePermission Success |
|<-------------------------------|
| |
TURN server now accepts traffic from 10.0.0.5
Permission lifetime: 300 seconds
3. Sending Data
Method A: Send Indication
Client TURN Server Peer
| | |
| Send Indication | |
| To: 10.0.0.5:6000 | |
| Data: "Hello" | |
|--------------------------->| |
| | UDP: "Hello" |
| |--------------------->|
| | |
Method B: Channel Binding (Optimized)
Client TURN Server Peer
| | |
| ChannelBind Request | |
| Channel: 0x4000 | |
| Peer: 10.0.0.5:6000 | |
|--------------------------->| |
| | |
| ChannelBind Success | |
|<---------------------------| |
| | |
| ChannelData (0x4000) | |
| Data: "Hello" | |
|--------------------------->| |
| | UDP: "Hello" |
| |--------------------->|
ChannelData has only 4-byte overhead (vs 36 bytes for Send)
More efficient for continuous data flow
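The ChannelData framing above is simple enough to sketch directly: a 4-byte header (channel number + payload length) in front of the application data:

```python
import struct

def channel_data(channel, payload):
    """Frame a TURN ChannelData message: 2-byte channel number, 2-byte length."""
    assert 0x4000 <= channel <= 0x7FFF, "valid channel number range"
    return struct.pack('!HH', channel, len(payload)) + payload

frame = channel_data(0x4000, b"Hello")
print(frame.hex())     # 40000005 followed by the bytes of 'Hello'
print(len(frame) - 5)  # 4 — the per-message overhead vs ~36 bytes for Send
```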
4. Receiving Data
Peer TURN Server Client
| | |
| UDP: "Reply" | |
|--------------------------->| |
| | Data Indication |
| | From: 10.0.0.5:6000 |
| | Data: "Reply" |
| |--------------------->|
| | |
TURN Message Format
Send Indication
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|0 0| STUN Message Type | Message Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Magic Cookie |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Transaction ID (96 bits) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| XOR-PEER-ADDRESS |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| DATA |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
ChannelData Message
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Channel Number | Length |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
| Application Data |
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Channel Number: 0x4000 - 0x7FFF
Length: Length of application data
TURN Attributes
Common Attributes
| Attribute | Type | Description |
|---|---|---|
| XOR-RELAYED-ADDRESS | 0x0016 | Relay transport address |
| XOR-PEER-ADDRESS | 0x0012 | Peer transport address |
| DATA | 0x0013 | Data to relay |
| LIFETIME | 0x000D | Allocation lifetime (seconds) |
| REQUESTED-TRANSPORT | 0x0019 | Desired transport (UDP=17) |
| CHANNEL-NUMBER | 0x000C | Channel number (0x4000-0x7FFF) |
| EVEN-PORT | 0x0018 | Request even port (RTP/RTCP) |
| DONT-FRAGMENT | 0x001A | Don't fragment |
| RESERVATION-TOKEN | 0x0022 | Token for port reservation |
TURN Authentication
Long-Term Credentials
Request 1 (no credentials):
Allocate Request
Response 1:
Error 401 Unauthorized
REALM: "example.com"
NONCE: "abcd1234"
Request 2 (with credentials):
USERNAME: "alice"
REALM: "example.com"
NONCE: "abcd1234"
MESSAGE-INTEGRITY: HMAC-SHA1(message, key)
Key = MD5(username:realm:password)
Response 2:
Allocate Success Response
XOR-RELAYED-ADDRESS: ...
LIFETIME: 600
Short-Term Credentials
Used within ICE (WebRTC):
USERNAME: <random>:<random>
PASSWORD: <shared secret>
Simpler, time-limited authentication
TURN Allocation Lifecycle
1. Allocate (request relay)
↓
2. Success (relay assigned)
↓
3. CreatePermission (allow peers)
↓
4. ChannelBind (optimize transfer)
↓
5. Send/Receive Data
↓
6. Refresh (extend lifetime)
↓
7. Delete or Expire
Timeline:
0s: Allocate
300s: Refresh (extend to 900s)
600s: Refresh (extend to 1200s)
900s: Refresh (extend to 1500s)
...
Stop refreshing: Allocation expires
TURN Over Different Transports
TURN over UDP
Default mode
Client → TURN Server: UDP
TURN Server → Peer: UDP
Fast, but UDP might be blocked
TURN over TCP
Client → TURN Server: TCP
TURN Server → Peer: UDP
Works when UDP blocked
More overhead (TCP vs UDP)
TURN over TLS
Client → TURN Server: TLS over TCP
TURN Server → Peer: UDP
Encrypted control channel
Works in restrictive environments
Port 443 (looks like HTTPS)
ICE with TURN
Candidate Priority
ICE tries candidates in order:
1. Host Candidate (local IP)
Type: host
Priority: Highest
Example: 192.168.1.10:5000
2. Server Reflexive (STUN)
Type: srflx
Priority: High
Example: 203.0.113.5:6000
3. Relayed (TURN)
Type: relay
Priority: Lowest (fallback)
Example: 198.51.100.1:50000
Connection attempt:
Try host → Try srflx → Try relay
Use first successful connection
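The ordering above can be sketched as a sort by type preference. Real ICE computes a full 32-bit priority per candidate; this keeps only the host > srflx > relay idea, with made-up preference values:

```python
# Assumed type-preference values, highest first (illustrative only)
TYPE_PREFERENCE = {"host": 126, "srflx": 100, "relay": 0}

candidates = [
    {"type": "relay", "addr": "198.51.100.1:50000"},
    {"type": "host",  "addr": "192.168.1.10:5000"},
    {"type": "srflx", "addr": "203.0.113.5:6000"},
]

# Try the cheapest/fastest candidate types first
ordered = sorted(candidates, key=lambda c: TYPE_PREFERENCE[c["type"]], reverse=True)
print([c["type"] for c in ordered])  # ['host', 'srflx', 'relay']
```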
WebRTC with TURN
const configuration = {
  iceServers: [
    // STUN server (free)
    { urls: 'stun:stun.l.google.com:19302' },
    // TURN server (requires auth)
    {
      urls: 'turn:turn.example.com:3478',
      username: 'alice',
      credential: 'password123'
    },
    // TURN over TLS
    {
      urls: 'turns:turn.example.com:5349',
      username: 'alice',
      credential: 'password123'
    }
  ]
};
const pc = new RTCPeerConnection(configuration);
TURN Server Setup
Using coturn
Install:
sudo apt-get install coturn
Configure /etc/turnserver.conf:
# Listening ports
listening-port=3478
tls-listening-port=5349
# Relay ports
min-port=49152
max-port=65535
# Authentication
lt-cred-mech
user=alice:password123
realm=example.com
# Or use shared secret
use-auth-secret
static-auth-secret=my-secret-key
# Certificates (for TLS)
cert=/etc/ssl/turn.crt
pkey=/etc/ssl/turn.key
# Logging
log-file=/var/log/turnserver.log
verbose
# External IP (if behind NAT)
external-ip=203.0.113.5/192.168.1.10
# Limit resources
max-bps=1000000
total-quota=100
Run:
sudo turnserver -v
Test:
# Using turnutils
turnutils_uclient -v -u alice -w password123 turn.example.com
TURN Bandwidth Considerations
Bandwidth Usage
Video call: 2 Mbps per direction
Direct P2P (no TURN):
Client A →→ Client B
Total bandwidth: 4 Mbps (2 up + 2 down each)
Through TURN relay:
Client A → TURN → Client B
TURN bandwidth: 4 Mbps (2 in + 2 out)
Each client: 4 Mbps (2 up + 2 down)
TURN server needs 2x the bandwidth!
Cost Implications
Example: 1000 concurrent video calls through TURN
Each call: 2 Mbps × 2 directions = 4 Mbps
Total: 1000 × 4 Mbps = 4 Gbps
At $0.10/GB:
4 Gbps = 0.5 GB/second
Per hour: 1,800 GB = $180/hour
Per day: 43,200 GB = $4,320/day
Why ICE tries direct connection first!
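The relay-cost arithmetic above, as a quick sanity check (same illustrative numbers: 1000 calls, 4 Mbps through the relay per call, $0.10/GB):

```python
calls = 1000
mbps_per_call = 4       # 2 Mbps in + 2 Mbps out through the relay
price_per_gb = 0.10     # USD, illustrative

total_gbps = calls * mbps_per_call / 1000  # 4.0 Gbps through the relay
gb_per_second = total_gbps / 8             # 0.5 GB/s (8 bits per byte)
gb_per_hour = gb_per_second * 3600         # 1800 GB

print(f"${gb_per_hour * price_per_gb:.0f}/hour")       # $180/hour
print(f"${gb_per_hour * 24 * price_per_gb:.0f}/day")   # $4320/day
```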
Optimization Strategies
1. Prefer direct connections (STUN)
- ~80% of connections succeed
- Zero relay bandwidth
2. Short allocation lifetimes
- Free up resources quickly
- Prevent unused allocations
3. Connection quality monitoring
- Switch from relay to direct if possible
- ICE restart
4. Rate limiting
- Prevent abuse
- Fair resource sharing
5. Geographic distribution
- Regional TURN servers
- Reduce latency
TURN Security
Authentication Required
Public TURN servers = expensive bandwidth
Must authenticate:
- Username/password
- Time-limited credentials
- Shared secrets
Quota Management
Limit per user:
- Bandwidth (bytes/sec)
- Total data (GB)
- Concurrent allocations
- Allocation lifetime
Access Control
Restrict by:
- IP ranges (corporate network)
- Time windows
- User groups
Monitoring TURN Server
Key Metrics
1. Active allocations
- Current number
- Peak usage
2. Bandwidth
- Total throughput
- Per-client usage
- Inbound/outbound ratio
3. Connections
- Success rate
- Allocation duration
- Peak concurrent
4. Authentication
- Failed attempts
- Expired credentials
5. Resources
- CPU usage
- Memory
- Network interfaces
- Port exhaustion
coturn Statistics
# Real-time stats
telnet localhost 5766
# Commands:
ps # Print sessions
pid # Show process info
pc # Print configuration
TURN Alternatives
1. Direct P2P (preferred)
Pros: Free, low latency
Cons: Doesn't always work
Success rate: ~80%
2. SIP/VoIP Gateways
Traditional VoIP infrastructure
Built-in media relays
More expensive
3. Media Servers
Janus, Jitsi, Kurento
Selective Forwarding Unit (SFU)
Different model than TURN
Troubleshooting TURN
Can't allocate
# Check TURN server is running
sudo systemctl status coturn
# Check listening ports
netstat -tuln | grep 3478
# Test with turnutils
turnutils_uclient -v turn.example.com
Authentication fails
# Verify credentials
turnutils_uclient -u alice -w password123 turn.example.com
# Check realm configuration
grep realm /etc/turnserver.conf
# Check logs
tail -f /var/log/turnserver.log
High latency
- Use geographically closer TURN server
- Check server load (CPU, bandwidth)
- Try TURN over TCP (sometimes faster)
- Monitor network path (traceroute)
ELI10
TURN is like using a friend to pass notes in class:
Without TURN (Direct):
- You throw note directly to friend
- Fast and easy
- But teacher might catch it!
With TURN (Through Relay):
- You give note to trusted student
- They walk it over to your friend
- Slower, but always works
- Even if teacher is watching
Why TURN?
Imagine you're in Building A, friend in Building B:
- Can't throw note that far (NAT/firewall blocking)
- Need someone in the middle to help
- TURN server is that helpful person
Costs:
- Direct (free): Just toss the note
- TURN (expensive): Someone must carry every note back and forth
- Video call = thousands of notes per second!
- TURN server gets tired (bandwidth costs)
Smart Strategy (ICE):
- Try throwing directly (host candidate)
- Try from outside (STUN)
- Last resort: Use TURN relay
Use TURN only when absolutely needed!
ICE (Interactive Connectivity Establishment)
Overview
ICE (Interactive Connectivity Establishment) is a framework used to establish peer-to-peer connections through NATs and firewalls. It's primarily used by WebRTC, VoIP applications, and other real-time communication systems to find the best path for connecting two endpoints on the internet.
The NAT Problem
Why ICE is Needed
Traditional Scenario:
┌──────────────┐ ┌──────────────┐
│ Client A │ │ Client B │
│ 10.0.0.5 │ │ 192.168.1.10 │
└──────┬───────┘ └──────┬───────┘
│ │
│ NAT NAT │
│ │
┌──────▼───────┐ ┌──────▼───────┐
│ Router │ │ Router │
│ 203.0.113.5 │ │ 198.51.100.3 │
└──────────────┘ └──────────────┘
│ │
└───────────── Internet ────────────┘
Problems:
1. Client A only knows its private IP (10.0.0.5)
2. Client B can't reach 10.0.0.5 (not routable)
3. Client A doesn't know Client B's public IP
4. Routers block unsolicited incoming connections
ICE Solution:
1. Discover public IPs (STUN)
2. Try multiple connection paths
3. Use relay as fallback (TURN)
4. Select best working path
How ICE Works
The ICE Process
1. Gather Candidates
Collect all possible ways to reach this peer:
- Host candidate (local IP)
- Server reflexive (public IP from STUN)
- Relayed candidate (TURN relay)
2. Exchange Candidates
Send candidates to remote peer via signaling
3. Pair Candidates
Create pairs: local candidate + remote candidate
4. Check Connectivity
Test all pairs in priority order
5. Select Best Pair
Use the pair with highest priority that works
6. Keep Alive
Maintain selected connection
Detailed Flow Diagram
Peer A Signaling Server Peer B
| | |
|──① Gather Candidates | |
| - host | |
| - srflx (STUN) | |
| - relay (TURN) | |
| | |
|──② Send Candidates──────────►| |
| via SDP offer | |
| |──③ Forward Candidates────────►|
| | |
| | |──④ Gather Candidates
| | | - host
| | | - srflx (STUN)
| | | - relay (TURN)
| | |
| |◄──⑤ Send Candidates───────────|
| | via SDP answer |
|◄─⑥ Forward Candidates────────| |
| | |
|──⑦ Connectivity Checks───────────────────────────────────────►|
| (test all candidate pairs) |
| | |
|◄─⑧ Connectivity Checks───────────────────────────────────────|
| (test all candidate pairs) |
| | |
|──⑨ Nomination (best pair)────────────────────────────────────►|
| | |
|◄─⑩ Confirmation──────────────────────────────────────────────|
| | |
|══⑪ Media/Data Flow ═══════════════════════════════════════════|
| (using selected pair) |
ICE Candidate Types
1. Host Candidate
Local network interface address:
Type: host
Address: 10.0.0.5:54321
Foundation: 1
Characteristics:
- Actual IP address of the interface
- No NAT traversal
- Works only on same local network
- Lowest latency
- Priority: High for local connections
Example:
candidate:1 1 UDP 2130706431 10.0.0.5 54321 typ host
Use case:
- Devices on same LAN
- No NAT between peers
2. Server Reflexive Candidate (srflx)
Public IP address as seen by STUN server:
Type: srflx
Address: 203.0.113.5:61234
Related: 10.0.0.5:54321
Foundation: 2
Characteristics:
- Discovered via STUN server
- Public IP:port after NAT
- Most common for internet connections
- Priority: Medium-High
Example:
candidate:2 1 UDP 1694498815 203.0.113.5 61234 typ srflx
raddr 10.0.0.5 rport 54321
Discovery:
1. Client sends STUN request from 10.0.0.5:54321
2. STUN server sees request from 203.0.113.5:61234
3. STUN responds with "Your IP:port is 203.0.113.5:61234"
4. Client creates srflx candidate
Use case:
- Typical internet connections
- NAT traversal
- Peer-to-peer over internet
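The srflx discovery steps above can be sketched in Python: build a raw STUN Binding Request and decode the XOR-MAPPED-ADDRESS attribute from a Binding Success response. This is a minimal sketch of RFC 5389 framing (IPv4 only, no networking, retries, or integrity checks):

```python
import os
import struct

MAGIC_COOKIE = 0x2112A442

def build_binding_request() -> bytes:
    """20-byte STUN header: type=0x0001 (Binding Request), length=0,
    magic cookie, and a 96-bit random transaction ID."""
    return struct.pack("!HHI12s", 0x0001, 0, MAGIC_COOKIE, os.urandom(12))

def parse_xor_mapped_address(response: bytes):
    """Walk the attributes of a Binding Success response and decode
    XOR-MAPPED-ADDRESS (type 0x0020): port is XORed with the top 16
    bits of the magic cookie, the IPv4 address with the full cookie."""
    pos = 20                                    # skip the 20-byte header
    while pos + 4 <= len(response):
        attr_type, attr_len = struct.unpack("!HH", response[pos:pos + 4])
        value = response[pos + 4:pos + 4 + attr_len]
        if attr_type == 0x0020:
            port = struct.unpack("!H", value[2:4])[0] ^ (MAGIC_COOKIE >> 16)
            raw_ip = struct.unpack("!I", value[4:8])[0] ^ MAGIC_COOKIE
            ip = ".".join(str((raw_ip >> s) & 0xFF) for s in (24, 16, 8, 0))
            return ip, port
        pos += 4 + attr_len + (-attr_len % 4)   # attributes are 32-bit aligned
    return None
```

Sending `build_binding_request()` over UDP to a STUN server and parsing the reply yields exactly the public IP:port used for the srflx candidate.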
3. Peer Reflexive Candidate (prflx)
Public IP discovered during connectivity checks:
Type: prflx
Address: 203.0.113.5:61235
Foundation: 3
Characteristics:
- Discovered during checks (not via STUN)
- Learned from peer's connectivity checks
- Alternative to srflx
- Priority: Medium
Example:
candidate:3 1 UDP 1862270975 203.0.113.5 61235 typ prflx
raddr 10.0.0.5 rport 54321
Discovery:
1. Peer B sends connectivity check
2. Peer A receives from unexpected address
3. Peer A learns new reflexive address
4. Creates prflx candidate
Use case:
- Discovered during connection attempts
- Additional connectivity options
4. Relayed Candidate (relay)
Address on TURN relay server:
Type: relay
Address: 198.51.100.10:55555
Related: 203.0.113.5:61234
Foundation: 4
Characteristics:
- Allocated on TURN server
- Relay forwards all traffic
- Works through any NAT/firewall
- Highest latency and bandwidth cost
- Priority: Low (fallback)
Example:
candidate:4 1 UDP 16777215 198.51.100.10 55555 typ relay
raddr 203.0.113.5 rport 61234
Discovery:
1. Client requests allocation from TURN server
2. TURN allocates 198.51.100.10:55555
3. Client creates relay candidate
4. All traffic flows through TURN
Use case:
- Symmetric NATs
- Restrictive firewalls
- When direct connection fails
- Corporate networks
Candidate Priority
Priority Calculation
Priority Formula:
priority = (2^24 × type preference) +
(2^8 × local preference) +
(256 - component ID)
Type Preference (higher = better):
- host: 126
- prflx: 110
- srflx: 100
- relay: 0
Local Preference:
- Higher for interfaces you prefer
- Typically: 65535 for best interface
Component ID:
- 1 for RTP (main media)
- 2 for RTCP (control)
Example Calculations:
Host candidate:
(2^24 × 126) + (2^8 × 65535) + (256 - 1)
= 2130706431
Srflx candidate:
(2^24 × 100) + (2^8 × 65535) + (256 - 1)
= 1694498815
Relay candidate:
(2^24 × 0) + (2^8 × 65535) + (256 - 1)
= 16777215
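These calculations can be verified directly in Python (a sketch of the RFC 8445 formula, not a WebRTC API):

```python
# RFC 8445 recommended type preferences (higher = preferred)
TYPE_PREFERENCE = {"host": 126, "prflx": 110, "srflx": 100, "relay": 0}

def candidate_priority(cand_type: str, local_pref: int = 65535,
                       component_id: int = 1) -> int:
    """priority = (2^24 * type preference) + (2^8 * local preference)
    + (256 - component ID)"""
    return ((1 << 24) * TYPE_PREFERENCE[cand_type]
            + (1 << 8) * local_pref
            + (256 - component_id))
```

Note that with these type preferences, prflx candidates outrank srflx: a peer-learned reflexive address is trusted slightly more than a STUN-learned one.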
Priority in Practice
Sorted by priority (high to low):
1. host (LAN) Priority: 2130706431
- Try first
- Works if same network
- Lowest latency
2. prflx (Discovered) Priority: 1862270975
- Try if discovered during checks
- Alternative path
3. srflx (NAT) Priority: 1694498815
- Works through NAT
- Good latency
4. relay (TURN) Priority: 16777215
- Try last
- Always works
- Higher latency/cost
Best path:
host → host (LAN)
host → srflx (NAT traversal)
srflx → srflx (Both behind NAT)
relay → relay (Fallback)
Candidate Gathering
ICE Gathering States
// ICE gathering state machine
peerConnection.onicegatheringstatechange = () => {
console.log('ICE gathering state:',
peerConnection.iceGatheringState);
};
/*
States:
1. new
- Initial state
- No gathering started
2. gathering
- Actively gathering candidates
- STUN/TURN requests in progress
3. complete
- All candidates gathered
- Ready to connect
*/
// Monitor gathering
peerConnection.addEventListener('icecandidate', (event) => {
if (event.candidate) {
console.log('New candidate:', event.candidate);
// Send to remote peer
} else {
console.log('Gathering complete');
// All candidates collected
}
});
Gathering Configuration
// Configure ICE servers
const configuration = {
iceServers: [
// Public STUN servers (Google)
{
urls: 'stun:stun.l.google.com:19302'
},
{
urls: 'stun:stun1.l.google.com:19302'
},
// STUN server (custom)
{
urls: 'stun:stun.example.com:3478'
},
// TURN server (UDP and TCP)
{
urls: [
'turn:turn.example.com:3478',
'turn:turn.example.com:3478?transport=tcp'
],
username: 'user',
credential: 'password',
credentialType: 'password'
},
// TURN server (TLS)
{
urls: 'turns:turn.example.com:5349',
username: 'user',
credential: 'password'
}
],
// ICE transport policy
iceTransportPolicy: 'all', // 'all' or 'relay'
// 'all': Try all candidates
// 'relay': Only use TURN (force relay)
// Candidate pool size
iceCandidatePoolSize: 10
// Pre-allocate TURN allocations
// Higher = faster but more resources
};
const peerConnection = new RTCPeerConnection(configuration);
Trickle ICE
Instead of waiting for all candidates, send them as discovered:
// Sender: Send candidates as discovered
peerConnection.onicecandidate = (event) => {
if (event.candidate) {
// Send immediately (trickle)
signaling.send({
type: 'ice-candidate',
candidate: event.candidate
});
} else {
// Signal end of candidates
signaling.send({
type: 'ice-candidate',
candidate: null
});
}
};
// Receiver: Add candidates as received
signaling.on('ice-candidate', async (message) => {
if (message.candidate) {
try {
await peerConnection.addIceCandidate(
new RTCIceCandidate(message.candidate)
);
} catch (error) {
console.error('Error adding candidate:', error);
}
} else {
// End of candidates
console.log('Remote candidate gathering complete');
}
});
Benefits:
- Faster connection establishment
- Start checks before all candidates gathered
- Improved user experience
Connectivity Checks
STUN Binding Requests
ICE uses STUN messages to test connectivity:
Connectivity Check Process:
1. Create Candidate Pairs
Local Candidate Remote Candidate Pair
10.0.0.5:54321 + 192.168.1.10:44444 = Pair 1
10.0.0.5:54321 + 198.51.100.3:55555 = Pair 2
203.0.113.5:61234 + 192.168.1.10:44444 = Pair 3
203.0.113.5:61234 + 198.51.100.3:55555 = Pair 4
198.51.100.10:55555 + 192.168.1.10:44444 = Pair 5
198.51.100.10:55555 + 198.51.100.3:55555 = Pair 6
2. Sort Pairs by Priority
Pair priority (RFC 8445) = 2^32 × MIN(G,D) + 2 × MAX(G,D) + (1 if G > D)
where G = controlling agent's candidate priority,
D = controlled agent's candidate priority
3. Send STUN Binding Request
From: Local candidate
To: Remote candidate
Message: STUN Binding Request
Attributes:
- USERNAME: ice-ufrag
- PRIORITY: candidate priority
- ICE-CONTROLLING or ICE-CONTROLLED
- MESSAGE-INTEGRITY: HMAC
4. Receive STUN Binding Response
From: Remote candidate
To: Local candidate
Message: STUN Binding Response (Success)
Attributes:
- XOR-MAPPED-ADDRESS
- MESSAGE-INTEGRITY
5. Mark Pair as Valid
If response received, pair works!
6. Nominate Best Pair
Controlling agent nominates highest priority valid pair
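Step 2's pair sorting uses the RFC 8445 pair-priority formula, sketched below in Python (G is the controlling agent's candidate priority, D the controlled agent's; the example values reuse the candidate priorities computed earlier):

```python
def pair_priority(g: int, d: int) -> int:
    """RFC 8445 candidate-pair priority:
    2^32 * MIN(G, D) + 2 * MAX(G, D) + (1 if G > D else 0).
    Weighting MIN by 2^32 means the weaker side of the pair dominates
    the ordering; the final bit is a deterministic tie-breaker."""
    return (1 << 32) * min(g, d) + 2 * max(g, d) + (1 if g > d else 0)

host, srflx, relay = 2130706431, 1694498815, 16777215
pairs = [(host, srflx), (relay, relay), (host, host), (srflx, srflx)]
# Check list: highest pair priority first
ordered = sorted(pairs, key=lambda p: pair_priority(*p), reverse=True)
```

Because MIN dominates, a host↔relay pair sorts below an srflx↔srflx pair: the pair is only as good as its worst candidate.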
Controlling vs Controlled
ICE Roles:
Controlling Agent (Caller):
- Makes final decision on selected pair
- Sends nomination
- Usually the offerer
Controlled Agent (Callee):
- Responds to checks
- Accepts nomination
- Usually the answerer
Role Conflict Resolution:
If both think they're controlling:
- Compare ICE tie-breaker values
- Higher value becomes controlling
- Lower value becomes controlled
Attribute:
ICE-CONTROLLING: <tie-breaker>
or
ICE-CONTROLLED: <tie-breaker>
Connectivity Check States
Candidate Pair States:
1. Frozen
- Initial state
- Waiting to be checked
- Not yet sent binding request
2. Waiting
- Ready to check
- Will check soon
- Waiting for resources
3. In Progress
- Binding request sent
- Waiting for response
- Timeout if no response
4. Succeeded
- Binding response received
- Pair is valid
- Can be used for media
5. Failed
- No response (timeout)
- Or error response
- Cannot use this pair
State Machine:
Frozen → Waiting → In Progress → Succeeded ✓
→ Failed ✗
Connection States
ICE Connection States
peerConnection.oniceconnectionstatechange = () => {
console.log('ICE connection state:',
peerConnection.iceConnectionState);
switch (peerConnection.iceConnectionState) {
case 'new':
// Initial state, gathering not started
console.log('ICE gathering starting...');
break;
case 'checking':
// Checking candidate pairs
console.log('Testing connectivity...');
break;
case 'connected':
// At least one working pair found
console.log('Connection established!');
break;
case 'completed':
// All checks done, best pair selected
console.log('ICE completed');
break;
case 'failed':
// All pairs failed
console.error('Connection failed');
// Fallback: restart ICE or use TURN
handleConnectionFailure();
break;
case 'disconnected':
// Lost connectivity (temporary?)
console.warn('Connection lost, attempting to reconnect...');
break;
case 'closed':
// Connection closed
console.log('Connection closed');
break;
}
};
// Overall connection state (combines ICE + DTLS)
peerConnection.onconnectionstatechange = () => {
console.log('Connection state:',
peerConnection.connectionState);
// States: new, connecting, connected, disconnected, failed, closed
};
ICE Restart
When connection fails or degrades:
// Restart ICE
async function restartIce(peerConnection) {
console.log('Restarting ICE...');
// Create new offer with iceRestart option
const offer = await peerConnection.createOffer({
iceRestart: true
});
await peerConnection.setLocalDescription(offer);
// Send new offer to peer
signaling.send({
type: 'offer',
sdp: offer
});
// New candidates will be gathered
// New connectivity checks will be performed
}
// Trigger restart on failure
peerConnection.oniceconnectionstatechange = () => {
if (peerConnection.iceConnectionState === 'failed') {
console.error('ICE failed, restarting...');
restartIce(peerConnection);
}
};
// Or restart on disconnection timeout
let disconnectTimeout;
peerConnection.oniceconnectionstatechange = () => {
if (peerConnection.iceConnectionState === 'disconnected') {
// Wait 5 seconds before restart
disconnectTimeout = setTimeout(() => {
if (peerConnection.iceConnectionState !== 'connected') {
restartIce(peerConnection);
}
}, 5000);
} else if (peerConnection.iceConnectionState === 'connected') {
clearTimeout(disconnectTimeout);
}
};
Advanced ICE Features
ICE Lite
Simplified ICE for servers:
ICE Lite:
- Only responds to checks (doesn't initiate)
- Only gathers host candidates
- Simpler implementation
- Used by servers (not browsers)
Standard ICE vs ICE Lite:
Standard ICE (Full Agent):
- Gathers all candidate types
- Sends connectivity checks
- Can be controlling or controlled
- Used by clients
ICE Lite:
- Only host candidates
- Only responds to checks
- Always controlled role
- Used by servers
Example: Media server
- Server uses ICE Lite
- Client uses full ICE
- Client initiates all checks
- Server just responds
Consent Freshness
Keep-alive mechanism:
Purpose:
- Verify peer still wants to receive
- Detect path changes
- Prevent unwanted traffic
Process:
1. Every 5 seconds, send STUN Binding Request
2. Peer responds with Binding Response
3. If no response for 30 seconds → disconnected
STUN Binding Request:
- From selected local candidate
- To selected remote candidate
- Authenticated (MESSAGE-INTEGRITY)
Failure:
- 30 seconds without response
- ICE state → disconnected
- May trigger ICE restart
Automatic in WebRTC:
- Browser handles automatically
- No manual intervention needed
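The timing rules above can be modeled with a tiny clock-injected tracker (a sketch; real stacks randomize the probe interval around 5 s per RFC 7675, and the class name here is illustrative):

```python
class ConsentTracker:
    """Consent-freshness bookkeeping: probe roughly every 5 s,
    declare consent expired after 30 s without a response.
    The current time is passed in, so the logic is testable
    without real clocks or sockets."""
    PROBE_INTERVAL = 5.0
    CONSENT_TIMEOUT = 30.0

    def __init__(self, now: float = 0.0):
        self.last_response = now

    def should_probe(self, now: float) -> bool:
        # Time to send another STUN Binding Request?
        return now - self.last_response >= self.PROBE_INTERVAL

    def on_response(self, now: float) -> None:
        # A Binding Response renews consent
        self.last_response = now

    def consent_expired(self, now: float) -> bool:
        # 30 s of silence => stop sending, ICE state goes to disconnected
        return now - self.last_response >= self.CONSENT_TIMEOUT
```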
Aggressive Nomination
Faster connection establishment:
Regular Nomination:
1. Check all pairs
2. Wait for all checks to complete
3. Nominate best pair
Time: Slow but optimal
Aggressive Nomination:
1. Check pairs in priority order
2. Nominate first working pair immediately
3. Continue checking in background
Time: Fast but may not be optimal
Trade-off:
- Aggressive: Faster connection, may not be best path
- Regular: Slower connection, guaranteed best path
Most WebRTC implementations use regular nomination
for better quality.
Debugging ICE
Analyzing ICE Candidates
// Log all candidates
peerConnection.onicecandidate = (event) => {
if (event.candidate) {
const candidate = event.candidate.candidate;
console.log('Candidate:', candidate);
// Parse candidate
const parts = candidate.split(' ');
const parsed = {
foundation: parts[0].split(':')[1],
component: parts[1],
protocol: parts[2],
priority: parts[3],
ip: parts[4],
port: parts[5],
type: parts[7],
relAddr: parts[9],
relPort: parts[11]
};
console.log('Parsed:', parsed);
// Identify issues
if (parsed.type === 'relay') {
console.warn('Using TURN relay (may indicate NAT/firewall issues)');
}
if (parsed.protocol === 'tcp') {
console.warn('Using TCP (UDP may be blocked)');
}
}
};
// Monitor selected pair
async function getSelectedPair(peerConnection) {
const stats = await peerConnection.getStats();
stats.forEach(report => {
if (report.type === 'candidate-pair' && report.state === 'succeeded') {
console.log('Selected pair:');
console.log(' Local:', report.localCandidateId);
console.log(' Remote:', report.remoteCandidateId);
console.log(' State:', report.state);
console.log(' Priority:', report.priority);
console.log(' RTT:', report.currentRoundTripTime);
console.log(' Bytes sent:', report.bytesSent);
console.log(' Bytes received:', report.bytesReceived);
}
});
}
// Check every second
setInterval(() => getSelectedPair(peerConnection), 1000);
Common ICE Issues
Issue: No candidates gathered
Cause: Missing or incorrect STUN/TURN config
Solution: Verify iceServers configuration
Issue: Only relay candidates
Cause: Restrictive firewall blocks UDP
Solution:
- Enable UDP ports
- Use TURN with TCP
- Check firewall rules
Issue: Connectivity checks fail
Cause: Firewall blocks STUN packets
Solution:
- Allow UDP 3478 (STUN)
- Allow UDP 49152-65535 (RTP)
- Use TURN as fallback
Issue: Connection works then fails
Cause: NAT binding timeout
Solution:
- Shorter keep-alive interval
- Use consent freshness
- ICE restart on failure
Issue: High latency
Cause: Using TURN relay when direct possible
Solution:
- Verify STUN server reachable
- Check NAT type (symmetric NAT requires TURN)
- Verify candidate priorities
Issue: One-way media
Cause: Asymmetric connectivity
Solution:
- Check firewall rules both directions
- Verify both peers send candidates
- Use TURN if necessary
ICE Testing Tools
# Test STUN server
stunclient stun.l.google.com 19302
# Output shows:
# - Your public IP
# - NAT type
# - Whether server is reachable
# Test TURN server
turnutils_uclient -v -u username -w password \
turn.example.com
# Test with ICE
# Browser: chrome://webrtc-internals
# - View all ICE candidates
# - See connectivity checks
# - Monitor selected pair
# Command-line ICE test
npm install -g wrtc-ice-tester
wrtc-ice-tester --stun stun:stun.l.google.com:19302
# Network debugging
tcpdump -i any -n udp port 3478 or portrange 49152-65535
# WebRTC test page
https://test.webrtc.org/
ICE Configuration Examples
Basic Configuration
// Minimal config (STUN only)
const config = {
iceServers: [
{ urls: 'stun:stun.l.google.com:19302' }
]
};
// With TURN fallback
const config = {
iceServers: [
{ urls: 'stun:stun.l.google.com:19302' },
{
urls: 'turn:turn.example.com:3478',
username: 'user',
credential: 'pass'
}
]
};
// Production config (redundancy)
const config = {
iceServers: [
// Multiple STUN servers
{ urls: 'stun:stun1.example.com:3478' },
{ urls: 'stun:stun2.example.com:3478' },
// TURN with TCP fallback
{
urls: [
'turn:turn.example.com:3478', // UDP
'turn:turn.example.com:3478?transport=tcp', // TCP
'turns:turn.example.com:5349' // TLS
],
username: 'user',
credential: 'pass'
}
],
iceCandidatePoolSize: 10,
iceTransportPolicy: 'all' // Try everything
};
Dynamic TURN Credentials
// Get temporary TURN credentials from your server
async function getTurnCredentials() {
const response = await fetch('/api/turn-credentials', {
headers: { 'Authorization': 'Bearer ' + token }
});
return await response.json();
/*
Returns:
{
urls: ['turn:turn.example.com:3478'],
username: 'temporary-user-12345',
credential: 'temporary-password',
ttl: 86400 // 24 hours
}
*/
}
// Use dynamic credentials
const turnCreds = await getTurnCredentials();
const config = {
iceServers: [
{ urls: 'stun:stun.example.com:3478' },
{
urls: turnCreds.urls,
username: turnCreds.username,
credential: turnCreds.credential,
credentialType: 'password'
}
]
};
const pc = new RTCPeerConnection(config);
Server-Side TURN Credential Generation
// Node.js server
const crypto = require('crypto');
function generateTurnCredentials(username, secret, ttl = 86400) {
const timestamp = Math.floor(Date.now() / 1000) + ttl;
const turnUsername = `${timestamp}:${username}`;
const hmac = crypto.createHmac('sha1', secret);
hmac.update(turnUsername);
const turnPassword = hmac.digest('base64');
return {
urls: [
'turn:turn.example.com:3478',
'turn:turn.example.com:3478?transport=tcp',
'turns:turn.example.com:5349'
],
username: turnUsername,
credential: turnPassword,
ttl: ttl
};
}
// API endpoint
app.get('/api/turn-credentials', authenticate, (req, res) => {
const credentials = generateTurnCredentials(
req.user.id,
process.env.TURN_SECRET,
86400 // 24 hours
);
res.json(credentials);
});
Performance Considerations
Minimizing Connection Time
// 1. Pre-gather candidates
const pc = new RTCPeerConnection({
iceServers: [...],
iceCandidatePoolSize: 10 // Pre-allocate TURN
});
// 2. Use trickle ICE (send candidates immediately)
pc.onicecandidate = (event) => {
if (event.candidate) {
signaling.send({ type: 'candidate', candidate: event.candidate });
}
};
// 3. Start gathering early
await pc.setLocalDescription(await pc.createOffer());
// 4. Use multiple STUN servers (parallel queries)
const config = {
iceServers: [
{ urls: 'stun:stun1.example.com:3478' },
{ urls: 'stun:stun2.example.com:3478' },
{ urls: 'stun:stun3.example.com:3478' }
]
};
// 5. Close old connection before creating new one
if (oldPeerConnection) {
oldPeerConnection.close();
}
// Typical connection time:
// - LAN: 100-500ms
// - Internet (NAT): 1-3 seconds
// - TURN relay: 2-5 seconds
Bandwidth Considerations
TURN Relay Bandwidth:
Scenario: 10 users in video call, all using TURN
Without TURN (P2P mesh):
Each user sends to 9 others directly
Total: 10 × 9 = 90 directed streams (45 bidirectional links)
User bandwidth: 9 video streams (upload + download)
With TURN relay:
Each user → TURN server → other users
Total: 10 × 9 through TURN
TURN bandwidth: 90 video streams
User bandwidth: Same (9 streams)
TURN costs:
- P2P: No relay bandwidth
- TURN: All traffic through server
- Solution: Use TURN only when necessary
Check if using TURN:
const stats = await pc.getStats();
stats.forEach(report => {
if (report.type === 'local-candidate' &&
report.candidateType === 'relay') {
console.warn('Using TURN relay!');
}
});
NAT Types and ICE Success
NAT Type Matrix
NAT Types (restrictiveness):
1. No NAT
✓ Direct connection
Success rate: 100%
2. Full Cone NAT
✓ Any external host can connect
Success rate: 100%
3. Restricted Cone NAT
✓ Can connect after outbound packet
Success rate: 95%
4. Port Restricted Cone NAT
✓ Can connect after outbound to specific port
Success rate: 90%
5. Symmetric NAT
✗ Different port for each destination
Needs TURN relay
Success rate: 100% (with TURN)
Connection Matrix:
Peer B
Full Restricted Symmetric
Peer A
Full ✓ ✓ ✓*
Restricted ✓ ✓ ✓*
Symmetric ✓* ✓* ✗ (need TURN)
✓ = Direct connection (STUN sufficient)
✓* = May need TURN
✗ = Requires TURN relay
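The matrix reduces to a simple decision rule, sketched in Python (the NAT-type labels are illustrative strings, not an API):

```python
def needs_turn(nat_a: str, nat_b: str) -> bool:
    """The ✗ cell: two symmetric NATs cannot hole-punch, because each
    allocates a fresh port per destination, so the STUN-learned port
    is never the one used toward the peer. A relay is required."""
    return nat_a == "symmetric" and nat_b == "symmetric"

def may_need_turn(nat_a: str, nat_b: str) -> bool:
    """The ✓* cells: exactly one side is symmetric. Direct connection
    often works but may still fall back to TURN."""
    return (nat_a == "symmetric") != (nat_b == "symmetric")
```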
ELI10: ICE Explained Simply
ICE is like finding the best way to connect two phones:
The Problem
You: Inside your house (private network)
Friend: Inside their house (private network)
Can't call directly:
- You don't know their full address
- Their house blocks unknown callers
- Your house blocks incoming calls
ICE Solution
1. Find All Your Phone Numbers
- Room extension (host): 101
- House number (srflx): (555) 123-4567
- Call-forwarding service (relay): (555) 999-0000
2. Share Numbers
- You send your 3 numbers to friend
- Friend sends their 3 numbers to you
3. Try All Combinations (9 attempts)
Your 101 → Their 101 (works if same house)
Your 101 → Their (555) 234-5678 (fails)
Your (555) 123-4567 → Their (555) 234-5678 (works!)
... etc
4. Use Best Connection
- Direct if possible (faster, cheaper)
- Through forwarding if necessary (works always)
5. Keep Checking
- "Are you still there?"
- If no answer, try again
Real Terms
House = Private network
House number = Public IP (STUN)
Call forwarding = Relay (TURN)
Trying combinations = Connectivity checks
Best connection = Selected candidate pair
Further Resources
Tools
- Trickle ICE - Test ICE candidates
- WebRTC Troubleshooter - Connection testing
- NAT Type Test - Identify NAT type
Debugging
- chrome://webrtc-internals - Chrome ICE debug
- about:webrtc - Firefox ICE debug
STUN/TURN Servers
- Coturn - Open source TURN server
- Xirsys - TURN server hosting
- Twilio STUN/TURN - Managed service
PCP (Port Control Protocol)
Overview
PCP (Port Control Protocol) is a protocol that allows hosts to control how incoming packets are forwarded by upstream devices such as NAT gateways and firewalls. It's the successor to NAT-PMP and provides more features and flexibility for port mapping and firewall control.
Key Characteristics
Protocol: UDP
Port: 5351
RFC: 6887 (2013)
Predecessor: NAT-PMP (RFC 6886)
Features:
✓ Port mapping (like NAT-PMP)
✓ Firewall control
✓ IPv4 and IPv6 support
✓ Explicit lifetime management
✓ Multiple NATs/firewalls
✓ Third-party port mapping
✓ Failure detection
✓ Security improvements
Why PCP?
Problems with Manual Port Forwarding
Traditional Approach:
1. User logs into router web interface
2. Manually configures port forwarding
3. Must remember to remove when done
4. Doesn't work with multiple NATs
5. Requires user intervention
Problems:
- Not suitable for applications
- Doesn't scale
- Security risk (ports left open)
- Complex for users
PCP Solution
Automated Approach:
1. Application requests port mapping via PCP
2. Router automatically configures forwarding
3. Mapping has lifetime (auto-expires)
4. Application can renew or delete
5. Works with cascaded NATs
Benefits:
✓ Fully automated
✓ Application-controlled
✓ Time-limited (secure)
✓ Works across multiple NATs
✓ Standardized protocol
PCP vs UPnP vs NAT-PMP
| Feature             | PCP      | UPnP-IGD   | NAT-PMP  |
|---------------------|----------|------------|----------|
| Protocol            | UDP      | HTTP/SOAP  | UDP      |
| Complexity          | Medium   | High       | Low      |
| IPv6 Support        | Yes      | Partial    | No       |
| Multiple NATs       | Yes      | No         | No       |
| Explicit Lifetime   | Yes      | No         | Yes      |
| Firewall Control    | Yes      | No         | No       |
| Third-party Mapping | Yes      | No         | No       |
| Security            | Good     | Weak       | Basic    |
| Standardization     | IETF RFC | UPnP Forum | IETF RFC |
Use PCP when:
- Need IPv6 support
- Multiple NATs in path
- Firewall control needed
- Modern deployment
Use NAT-PMP when:
- Simple IPv4 NAT
- Apple ecosystem
- Lightweight solution
Use UPnP when:
- Legacy device support
- Already deployed
- Complex scenarios
PCP Architecture
┌─────────────────────────────────────────────────────────────┐
│ Internet │
└─────────────────┬───────────────────────────────────────────┘
│
┌────────▼────────┐
│ ISP Router │
│ (PCP Server) │
└────────┬────────┘
│
┌────────▼────────┐
│ Home Router │
│ (PCP Server) │ ← Responds to PCP requests
└────────┬────────┘
│
┌────────▼────────┐
│ PCP Client │ ← Sends PCP requests
│ (Application) │
└─────────────────┘
Flow:
1. Client sends PCP request to server
2. Server creates/modifies mapping
3. Server responds with mapping details
4. Client maintains mapping with renewals
5. Mapping expires or client deletes it
PCP Message Format
Request Header
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Version = 2 |R| Opcode | Reserved |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Requested Lifetime (seconds) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
| PCP Client's IP Address (128 bits) |
| |
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
: Opcode-specific data :
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
: PCP Options (optional) :
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Fields:
Version (8 bits): Protocol version (2)
R (1 bit): 0 for request, 1 for response
Opcode (7 bits):
- 0: ANNOUNCE
- 1: MAP
- 2: PEER
Reserved (16 bits): Must be 0
Requested Lifetime (32 bits): Seconds (0 = delete)
PCP Client IP (128 bits): Client's IP address
- IPv4: ::ffff:a.b.c.d (IPv4-mapped)
- IPv6: Full 128-bit address
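Packing the common request header is straightforward with the standard library; a Python sketch of the layout above (the function name is illustrative):

```python
import ipaddress
import struct

PCP_VERSION = 2

def pcp_request_header(opcode: int, lifetime: int, client_ip: str) -> bytes:
    """24-byte PCP common request header (RFC 6887): version, R=0 plus
    the 7-bit opcode, 16 reserved bits, 32-bit requested lifetime, and
    the client's IP as a 128-bit value. IPv4 addresses are encoded as
    IPv4-mapped IPv6 (::ffff:a.b.c.d)."""
    ip = ipaddress.ip_address(client_ip)
    if ip.version == 4:
        ip = ipaddress.ip_address(f"::ffff:{client_ip}")
    # opcode & 0x7F keeps the R bit 0 (request)
    return struct.pack("!BBHI16s", PCP_VERSION, opcode & 0x7F, 0, lifetime, ip.packed)
```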
Response Header
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Version = 2 |R| Opcode | Reserved | Result Code |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Lifetime (seconds) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Epoch Time (seconds) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Reserved (96 bits) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
: Opcode-specific data :
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
: PCP Options (optional) :
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Result Code:
0: SUCCESS
1: UNSUPP_VERSION
2: NOT_AUTHORIZED
3: MALFORMED_REQUEST
4: UNSUPP_OPCODE
5: UNSUPP_OPTION
6: MALFORMED_OPTION
7: NETWORK_FAILURE
8: NO_RESOURCES
9: UNSUPP_PROTOCOL
10: USER_EX_QUOTA
11: CANNOT_PROVIDE_EXTERNAL
12: ADDRESS_MISMATCH
13: EXCESSIVE_REMOTE_PEERS
Epoch Time:
- Seconds since PCP server started
- Used to detect server reboots
- Client must refresh mappings if changed
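The epoch check can be sketched as follows, using the drift tolerance described in RFC 6887 (roughly 1/16 of the elapsed client time plus 2 seconds of slack); this is a simplified version assuming monotonic client clocks:

```python
def epoch_valid(prev_client: int, prev_server: int,
                curr_client: int, curr_server: int) -> bool:
    """Compare how much time passed on the client vs. in the server's
    Epoch Time field between two responses. A mismatch beyond the
    allowed drift means the PCP server likely rebooted, and the client
    must refresh all of its mappings."""
    if curr_server < prev_server:          # epoch went backwards: reboot
        return False
    client_delta = curr_client - prev_client
    server_delta = curr_server - prev_server
    return (server_delta + 2 >= client_delta - client_delta // 16 and
            server_delta - 2 <= client_delta + client_delta // 16)
```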
PCP Opcodes
MAP Opcode
Create a mapping for inbound traffic:
MAP Request (after common header):
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Mapping Nonce |
| (96 bits) |
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Protocol | Reserved (24 bits) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Internal Port | Suggested External Port |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
| Suggested External IP Address (128 bits) |
| |
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Mapping Nonce:
- Random value to match request/response
- Prevents off-path attacks
Protocol:
- 6 = TCP
- 17 = UDP
- 0 = All protocols
Internal Port:
- Port on PCP client
Suggested External Port:
- Preferred external port
- 0 = server chooses
Suggested External IP:
- Preferred external IP
- 0 = server chooses
MAP Response (after common header):
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Mapping Nonce |
| (96 bits) |
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Protocol | Reserved (24 bits) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Internal Port | Assigned External Port |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
| Assigned External IP Address (128 bits) |
| |
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Server assigns:
- External port (may differ from suggested)
- External IP address
- Lifetime for mapping
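The MAP opcode payload that follows the common header can be packed the same way; a Python sketch of the layout above (all-zero suggested IP/port lets the server choose; the function name is illustrative):

```python
import ipaddress
import os
import struct

def pcp_map_body(protocol: int, internal_port: int,
                 suggested_port: int = 0, suggested_ip: str = "::") -> bytes:
    """36-byte MAP opcode payload (RFC 6887): 96-bit random mapping
    nonce, protocol (6=TCP, 17=UDP, 0=all), 24 reserved bits, internal
    port, suggested external port, and suggested external IP."""
    ip = ipaddress.ip_address(suggested_ip).packed
    return struct.pack("!12sB3sHH16s", os.urandom(12), protocol,
                       b"\x00" * 3, internal_port, suggested_port, ip)
```

A full MAP request is then `pcp_request_header(1, lifetime, client_ip) + pcp_map_body(...)`, sent over UDP to port 5351 on the default gateway.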
PEER Opcode
Create a mapping for bidirectional traffic with a specific peer:
PEER Request (after common header):
Similar to MAP, but includes remote peer address:
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Mapping Nonce |
| (96 bits) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Protocol | Reserved (24 bits) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Internal Port | Suggested External Port |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
| Suggested External IP Address (128 bits) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Remote Peer Port | Reserved (16 bits) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
| Remote Peer IP Address (128 bits) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Remote Peer Port:
- Port on remote peer
Remote Peer IP:
- IP address of remote peer
Use cases:
- P2P applications
- WebRTC
- VoIP
- Gaming
ANNOUNCE Opcode
Solicit mappings from PCP-controlled devices:
Used by client to discover mappings after:
- Client restart
- Network change
- Epoch time mismatch
Server responds with all active mappings for client
PCP Options
THIRD_PARTY Option
Allow one host to request mappings for another:
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Option Code=1| Reserved | Option Length=16 |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
| Internal IP Address (128 bits) |
| |
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Use case:
- NAT gateway requests mapping for internal host
- Application server requests for clients
- Proxy services
PREFER_FAILURE Option
Indicate client prefers error over server changing parameters:
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Option Code=2| Reserved | Option Length=0 |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
With this option:
- Server must honor requested port/IP exactly
- Or return error
- No substitutions allowed
Without this option:
- Server can assign different port/IP
- Client should accept
FILTER Option
Create a firewall filter:
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Option Code=3| Reserved | Option Length=20 |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Reserved | Prefix Length | Remote Peer Port |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
| Remote Peer IP Address (128 bits) |
| |
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Prefix Length:
- 0 = Allow all
- 1-128 = IP prefix match
Use case:
- Restrict mapping to specific source
- Security filtering
- Allow only known peers
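All three options share the same TLV envelope: an 8-bit option code, 8 reserved bits, and a 16-bit length, followed by the option data. A sketch of an encoder, with FILTER as the worked case (helper names are mine):

```python
import struct

OPT_THIRD_PARTY, OPT_PREFER_FAILURE, OPT_FILTER = 1, 2, 3

def encode_option(code: int, data: bytes = b'') -> bytes:
    """Generic PCP option TLV: code(8) | reserved(8) | length(16) | data."""
    return struct.pack('!BBH', code, 0, len(data)) + data

def filter_option(prefix_len: int, peer_port: int, peer_ip: bytes) -> bytes:
    """FILTER data: reserved(8) | prefix length(8) | peer port(16) | ip(128)."""
    assert len(peer_ip) == 16
    return encode_option(OPT_FILTER,
                         struct.pack('!BBH', 0, prefix_len, peer_port) + peer_ip)

# PREFER_FAILURE carries no data: just the 4-byte option header
assert encode_option(OPT_PREFER_FAILURE) == b'\x02\x00\x00\x00'
# FILTER is the 4-byte header plus 20 bytes of data, matching Length=20 above
assert len(filter_option(128, 6881, bytes(16))) == 24
```

Options are simply appended to the request after the opcode data.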
PCP Client Implementation
Python Example
import socket
import struct
import random
import time
class PCPClient:
PCP_VERSION = 2
PCP_SERVER_PORT = 5351
OPCODE_MAP = 1
OPCODE_PEER = 2
def __init__(self, server_ip):
self.server_ip = server_ip
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.sock.settimeout(3)
def create_mapping(self, internal_port, external_port=0,
protocol=6, lifetime=3600):
"""
Create a port mapping.
Args:
internal_port: Port on client
external_port: Suggested external port (0 = any)
protocol: 6=TCP, 17=UDP
lifetime: Mapping lifetime in seconds
Returns:
(external_ip, external_port, lifetime)
"""
# Generate random nonce
nonce = random.randint(0, 2**96 - 1)
# Build request
request = self._build_map_request(
nonce, protocol, internal_port,
external_port, lifetime
)
# Send request
self.sock.sendto(request, (self.server_ip, self.PCP_SERVER_PORT))
try:
# Receive response
response, addr = self.sock.recvfrom(1024)
return self._parse_map_response(response, nonce)
except socket.timeout:
raise Exception("PCP request timeout")
def delete_mapping(self, internal_port, protocol=6):
"""Delete a mapping by setting lifetime to 0."""
return self.create_mapping(
internal_port,
protocol=protocol,
lifetime=0
)
def _build_map_request(self, nonce, protocol, internal_port,
external_port, lifetime):
"""Build MAP request packet."""
# Common header
version_r_opcode = (self.PCP_VERSION << 8) | self.OPCODE_MAP
reserved = 0
# Client IP (IPv4-mapped IPv6)
client_ip = self._get_client_ip()
client_ip_bytes = self._ipv4_to_ipv6_mapped(client_ip)
# MAP opcode data
nonce_bytes = nonce.to_bytes(12, 'big')
protocol_byte = protocol
reserved_24 = 0
internal_port_field = internal_port
external_port_field = external_port
external_ip_bytes = bytes(16) # All zeros = any
# Pack request
request = struct.pack(
'!HHI',
version_r_opcode,
reserved,
lifetime
)
request += client_ip_bytes
request += nonce_bytes
request += struct.pack(
'!BxxxHH',
protocol_byte,
internal_port_field,
external_port_field
)
request += external_ip_bytes
return request
def _parse_map_response(self, response, expected_nonce):
"""Parse MAP response packet."""
# Parse common header
version_r_opcode, reserved_result, lifetime, epoch = \
struct.unpack('!HHII', response[:12])
# Extract result code
result = reserved_result & 0xFF
if result != 0:
raise Exception(f"PCP error: result code {result}")
# Skip reserved bytes
offset = 12 + 12 # Header + reserved
# Parse MAP response data
nonce_bytes = response[offset:offset+12]
nonce = int.from_bytes(nonce_bytes, 'big')
if nonce != expected_nonce:
raise Exception("Nonce mismatch")
offset += 12
protocol, internal_port, external_port = \
struct.unpack('!BxxxHH', response[offset:offset+8])
offset += 8
external_ip_bytes = response[offset:offset+16]
external_ip = self._ipv6_mapped_to_ipv4(external_ip_bytes)
return (external_ip, external_port, lifetime)
def _get_client_ip(self):
"""Get client's local IP address."""
# Connect to PCP server to determine local IP
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
s.connect((self.server_ip, self.PCP_SERVER_PORT))
return s.getsockname()[0]
finally:
s.close()
def _ipv4_to_ipv6_mapped(self, ipv4):
"""Convert IPv4 address to IPv4-mapped IPv6."""
parts = [int(p) for p in ipv4.split('.')]
# ::ffff:a.b.c.d
return bytes([0]*10 + [0xff, 0xff] + parts)
def _ipv6_mapped_to_ipv4(self, ipv6_bytes):
"""Convert IPv4-mapped IPv6 to IPv4."""
if ipv6_bytes[:12] == bytes([0]*10 + [0xff, 0xff]):
# IPv4-mapped
return '.'.join(str(b) for b in ipv6_bytes[12:])
else:
# Full IPv6 - return as string
parts = struct.unpack('!8H', ipv6_bytes)
return ':'.join(f'{p:x}' for p in parts)
def close(self):
self.sock.close()
# Usage example
if __name__ == '__main__':
# Find PCP server (usually gateway)
gateway = '192.168.1.1'
client = PCPClient(gateway)
try:
# Create mapping for local port 8080
external_ip, external_port, lifetime = \
client.create_mapping(
internal_port=8080,
external_port=8080, # Suggest same port
protocol=6, # TCP
lifetime=3600 # 1 hour
)
print(f"Mapping created:")
print(f" External: {external_ip}:{external_port}")
print(f" Internal: localhost:8080")
print(f" Lifetime: {lifetime} seconds")
# Keep mapping alive
print("\nMapping active. Press Ctrl+C to delete...")
try:
while True:
# Renew every 30 minutes
time.sleep(1800)
external_ip, external_port, lifetime = \
client.create_mapping(8080, protocol=6, lifetime=3600)
print(f"Mapping renewed: {lifetime}s remaining")
except KeyboardInterrupt:
pass
# Delete mapping
print("\nDeleting mapping...")
client.delete_mapping(8080)
print("Mapping deleted")
finally:
client.close()
Node.js Example
const dgram = require('dgram');
const crypto = require('crypto');
class PCPClient {
constructor(serverIP) {
this.serverIP = serverIP;
this.serverPort = 5351;
this.socket = dgram.createSocket('udp4');
this.PCP_VERSION = 2;
this.OPCODE_MAP = 1;
}
async createMapping(internalPort, externalPort = 0, protocol = 6, lifetime = 3600) {
const nonce = crypto.randomBytes(12);
const request = this.buildMapRequest(
nonce,
protocol,
internalPort,
externalPort,
lifetime
);
return new Promise((resolve, reject) => {
const timeout = setTimeout(() => {
reject(new Error('PCP request timeout'));
}, 3000);
this.socket.once('message', (response) => {
clearTimeout(timeout);
try {
const result = this.parseMapResponse(response, nonce);
resolve(result);
} catch (error) {
reject(error);
}
});
this.socket.send(request, this.serverPort, this.serverIP);
});
}
buildMapRequest(nonce, protocol, internalPort, externalPort, lifetime) {
const buffer = Buffer.alloc(60);
let offset = 0;
// Version and opcode
buffer.writeUInt8(this.PCP_VERSION, offset++);
buffer.writeUInt8(this.OPCODE_MAP, offset++);
// Reserved
buffer.writeUInt16BE(0, offset);
offset += 2;
// Lifetime
buffer.writeUInt32BE(lifetime, offset);
offset += 4;
// Client IP (IPv4-mapped)
buffer.fill(0, offset, offset + 10);
offset += 10;
buffer.writeUInt16BE(0xffff, offset);
offset += 2;
// Real source IP must be written here; a compliant server
// rejects an all-zero client IP with ADDRESS_MISMATCH
offset += 4;
// Nonce
nonce.copy(buffer, offset);
offset += 12;
// Protocol
buffer.writeUInt8(protocol, offset);
offset += 4; // 1 byte + 3 reserved
// Ports
buffer.writeUInt16BE(internalPort, offset);
offset += 2;
buffer.writeUInt16BE(externalPort, offset);
offset += 2;
// External IP (all zeros = any)
buffer.fill(0, offset, offset + 16);
return buffer;
}
parseMapResponse(response, expectedNonce) {
let offset = 0;
// Parse header
const version = response.readUInt8(offset++);
const opcode = response.readUInt8(offset++) & 0x7f;
const reserved = response.readUInt8(offset++);
const result = response.readUInt8(offset++);
const lifetime = response.readUInt32BE(offset);
offset += 4;
if (result !== 0) {
throw new Error(`PCP error: result code ${result}`);
}
// Skip epoch and reserved
offset += 16;
// Check nonce
const nonce = response.slice(offset, offset + 12);
if (!nonce.equals(expectedNonce)) {
throw new Error('Nonce mismatch');
}
offset += 12;
// Parse MAP data
const protocol = response.readUInt8(offset);
offset += 4; // 1 byte + 3 reserved
const internalPort = response.readUInt16BE(offset);
offset += 2;
const externalPort = response.readUInt16BE(offset);
offset += 2;
// External IP
const externalIP = this.parseIP(response.slice(offset, offset + 16));
return { externalIP, externalPort, lifetime };
}
parseIP(buffer) {
// Check if IPv4-mapped
if (buffer.slice(0, 12).equals(Buffer.from([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff]))) {
return `${buffer[12]}.${buffer[13]}.${buffer[14]}.${buffer[15]}`;
}
// IPv6
const parts = [];
for (let i = 0; i < 16; i += 2) {
parts.push(buffer.readUInt16BE(i).toString(16));
}
return parts.join(':');
}
close() {
this.socket.close();
}
}
// Usage
const client = new PCPClient('192.168.1.1');
client.createMapping(8080, 8080, 6, 3600)
.then(result => {
console.log('Mapping created:');
console.log(` External: ${result.externalIP}:${result.externalPort}`);
console.log(` Lifetime: ${result.lifetime}s`);
})
.catch(error => {
console.error('Error:', error.message);
})
.finally(() => {
client.close();
});
PCP Server Discovery
Methods to find PCP server:
1. DHCP Option
- Option 158 (DHCPv4, OPTION_V4_PCP_SERVER, RFC 7291)
- Option 86 (DHCPv6, OPTION_V6_PCP_SERVER, RFC 7291)
- Contains PCP server IP address
2. Default Gateway
- Try gateway address first
- Most common case
3. Well-Known Anycast Address
- IPv4: 192.0.0.9 (RFC 7723)
- IPv6: 2001:1::1 (RFC 7723)
4. Manual Configuration
- User configures PCP server
- For complex networks
Discovery process:
1. Check DHCP options
2. Try default gateway
3. Try manual config
4. Give up (no PCP available)
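The discovery order above can be sketched as a simple probe loop. This assumes a server answers any well-formed ANNOUNCE request (opcode 0); the function names are mine:

```python
import socket
import struct

PCP_PORT = 5351

def probe_pcp(server_ip: str, timeout: float = 1.0) -> bool:
    """Send a minimal PCP ANNOUNCE request and wait for any reply."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    sock.settimeout(timeout)
    try:
        sock.connect((server_ip, PCP_PORT))
        local_ip = sock.getsockname()[0]          # our source address
        client_ip = (bytes(10) + b'\xff\xff'
                     + socket.inet_aton(local_ip))  # IPv4-mapped IPv6
        # version=2, opcode=0 (ANNOUNCE), reserved, lifetime=0, client IP
        request = struct.pack('!BBHI', 2, 0, 0, 0) + client_ip
        sock.send(request)
        sock.recv(1024)                            # any reply counts
        return True
    except OSError:                                # timeout or refused
        return False
    finally:
        sock.close()

def discover_pcp_server(dhcp_hint=None, gateway=None, manual=None):
    """Return the first responding candidate, in the order listed above."""
    for candidate in (dhcp_hint, gateway, manual):
        if candidate and probe_pcp(candidate):
            return candidate
    return None  # no PCP available
```

In practice the default gateway is the candidate that succeeds on most home networks.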
Security Considerations
Authentication:
- PCP has no built-in authentication
- Relies on network trust
- Server trusts requests from local network
Threats:
1. Unauthorized mappings
- Malware opens ports
- Mitigation: Firewall rules on server
2. Mapping hijacking
- Another host modifies mapping
- Mitigation: Nonce verification
3. Denial of service
- Exhaust mapping resources
- Mitigation: Per-client quotas
4. Information disclosure
- Reveal internal topology
- Mitigation: Restrict query responses
Best practices:
- Deploy PCP-aware firewall
- Monitor mapping activity
- Set reasonable quotas
- Log suspicious requests
- Use short lifetimes
Common Use Cases
1. Gaming
# Game server
pcp = PCPClient('192.168.1.1')
# Create mapping for game server
external_ip, external_port, _ = pcp.create_mapping(
internal_port=27015, # Game server port
external_port=27015,
protocol=17, # UDP
lifetime=7200 # 2 hours
)
print(f"Server address: {external_ip}:{external_port}")
print("Share this with friends to join!")
# Register with matchmaking
register_with_matchmaking(external_ip, external_port)
# Keep mapping alive
while game_running:
time.sleep(3600)
pcp.create_mapping(27015, protocol=17, lifetime=7200)
2. P2P Applications
# P2P file sharing
pcp = PCPClient(gateway)
# Create PEER mapping for specific peer
peer_ip = '203.0.113.50'
peer_port = 6881
mapping = pcp.create_peer_mapping(
internal_port=6881,
peer_ip=peer_ip,
peer_port=peer_port,
protocol=6, # TCP
lifetime=3600
)
print(f"Connected to peer: {peer_ip}:{peer_port}")
print(f"Via external: {mapping['external_ip']}:{mapping['external_port']}")
3. IoT Devices
# Smart home device
pcp = PCPClient(gateway)
# Create long-lived mapping
external_ip, external_port, lifetime = pcp.create_mapping(
internal_port=8883, # MQTT over TLS
protocol=6,
lifetime=86400 # 24 hours
)
# Register with cloud service
register_device(device_id, external_ip, external_port)
# Renew daily
schedule_renewal(pcp, 8883, 86400)
Troubleshooting
# Check if PCP server is responding
nc -u 192.168.1.1 5351
# Send test request (hex, converted to binary with xxd)
echo -n "020100000000..." | xxd -r -p | nc -u 192.168.1.1 5351
# tcpdump PCP traffic
sudo tcpdump -i any -n udp port 5351
# Example output:
# Request
# 02 01 00 00 00 0e 10 00 # Version, opcode, reserved, lifetime
# 00 00 00 00 00 00 00 00 # Client IP (first 8 bytes)
# 00 00 ff ff c0 a8 01 64 # Client IP (last 8 bytes)
# ...
# Check router logs
# Look for "PCP" or "port mapping"
# Test with pcpdump (if available)
pcpdump -i eth0
# Common issues:
# - Router doesn't support PCP
# - PCP disabled in router config
# - Firewall blocks UDP 5351
# - Multiple NATs in path
# - Quota exceeded
ELI10: PCP Explained Simply
PCP is like asking the gatekeeper to let your friends visit:
Without PCP (Manual)
You: "Mom, can you open the door at 3pm for my friend?"
Mom: Manually opens door at 3pm
Friend: Can enter
Problem: Mom must remember, manual work
With PCP (Automatic)
You: "Open door for 2 hours when friend arrives"
Smart Lock: Automatically opens
Friend: Arrives, enters
Smart Lock: Closes after 2 hours
Benefits:
- Automatic
- Time-limited
- You control it
- No manual work
Real Network
Your App: "Need port 8080 open for 1 hour"
Router: Creates port mapping
Internet: Can now reach your app
Router: Closes port after 1 hour
Secure because:
- Time-limited
- Application controlled
- Automatic cleanup
Further Resources
Specifications
Implementations
Tools
- pcpdump - PCP packet analyzer
- pcptest - PCP testing tool
Comparison
NAT-PMP (NAT Port Mapping Protocol)
Overview
NAT-PMP (NAT Port Mapping Protocol) is a network protocol for automatically establishing port forwarding rules in a NAT gateway. It provides a simple, lightweight mechanism for applications to request port mappings without manual configuration. NAT-PMP was developed by Apple around 2005 and later published as RFC 6886 (2013).
Key Characteristics
Protocol: UDP
Port: 5351
RFC: 6886 (2013)
Developed by: Apple Inc.
Successor: PCP (Port Control Protocol)
Features:
✓ Automatic port mapping
✓ Simple protocol (easy to implement)
✓ UDP-based (low overhead)
✓ Time-limited mappings
✓ Gateway discovery
✓ External address discovery
✓ Lightweight
Limitations:
✗ IPv4 only
✗ Single NAT only
✗ No authentication
✗ Limited features vs PCP
Why NAT-PMP?
The Problem
Traditional Port Forwarding:
1. User manually logs into router
2. Navigates to port forwarding settings
3. Adds rule: External Port → Internal IP:Port
4. Application must document this for users
5. Users often configure incorrectly
6. Ports left open indefinitely
Issues:
- Not user-friendly
- Security risk (forgotten mappings)
- Doesn't work for non-technical users
- Can't be automated by applications
NAT-PMP Solution
Automatic Approach:
1. Application requests mapping via NAT-PMP
2. Router creates mapping automatically
3. Mapping has expiration time
4. Application renews as needed
5. Mapping removed when no longer needed
Benefits:
✓ Zero user configuration
✓ Automatic cleanup
✓ Application-controlled
✓ Simple to implement
✓ Secure (time-limited)
NAT-PMP vs Alternatives
| Feature         | NAT-PMP  | UPnP-IGD  | PCP       |
|-----------------|----------|-----------|-----------|
| Protocol        | UDP      | HTTP/SOAP | UDP       |
| Complexity      | Low      | High      | Medium    |
| IPv6 Support    | No       | Partial   | Yes       |
| Port            | 5351     | Variable  | 5351      |
| Packet Size     | 12 bytes | KB+       | 24+ bytes |
| Overhead        | Minimal  | High      | Low       |
| Deployment      | Apple    | Wide      | Growing   |
| Year Introduced | 2005     | 2000      | 2013      |
Use NAT-PMP when:
- IPv4 only network
- Simple requirements
- Apple ecosystem
- Lightweight solution
- Easy implementation
Use PCP when:
- Need IPv6
- Modern deployment
- Advanced features
- Multiple NATs
Use UPnP when:
- Legacy compatibility
- Already deployed
- Complex scenarios
Protocol Design
Message Types
Request Types (Client → NAT Gateway):
- Opcode 0: Determine external IP address
- Opcode 1: Map UDP port
- Opcode 2: Map TCP port
Response Types (NAT Gateway → Client):
- Opcode 128: External IP address response
- Opcode 129: UDP port mapping response
- Opcode 130: TCP port mapping response
All opcodes in responses have bit 7 set (add 128)
Packet Format
All NAT-PMP packets start with:
0 1
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Version = 0 | Opcode |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Version: Always 0
Opcode: Request or response type
External IP Address Request
Request Format
0 1
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Version = 0 | Opcode = 0 |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Total: 2 bytes
Purpose:
- Discover NAT gateway's external IP
- Check if NAT-PMP is supported
- Verify connectivity to gateway
Response Format
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Version = 0 | Opcode = 128 | Result Code |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Seconds Since Start of Epoch |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| External IP Address |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Total: 12 bytes
Result Code:
0: Success
1: Unsupported Version
2: Not Authorized/Refused
3: Network Failure
4: Out of Resources
5: Unsupported Opcode
Seconds Since Start of Epoch:
- Time since gateway booted/restarted
- Used to detect gateway reboots
- Incremented every second
External IP Address:
- Gateway's public IP address
- 32-bit IPv4 address
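For error reporting, the result codes above translate directly into a lookup table (a small convenience sketch; names are mine):

```python
# NAT-PMP result codes from RFC 6886, as listed above
NATPMP_RESULTS = {
    0: 'Success',
    1: 'Unsupported Version',
    2: 'Not Authorized/Refused',
    3: 'Network Failure',
    4: 'Out of Resources',
    5: 'Unsupported Opcode',
}

def describe_result(code: int) -> str:
    """Human-readable message for a NAT-PMP result code."""
    return NATPMP_RESULTS.get(code, f'Unknown result code {code}')

assert describe_result(2) == 'Not Authorized/Refused'
```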
Port Mapping Request
Request Format
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Version = 0 | Opcode (1/2) | Reserved (0) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Internal Port | Suggested External Port |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Requested Port Mapping Lifetime |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Total: 12 bytes
Opcode:
- 1 = UDP port mapping
- 2 = TCP port mapping
Reserved: Must be 0
Internal Port:
- Port on the client machine
- Port application is listening on
Suggested External Port:
- Preferred external port
- 0 = gateway chooses
- Non-zero = client preference
Requested Lifetime:
- Duration in seconds
- 0 = delete mapping
- Recommended: 3600 (1 hour)
- Maximum: 2^32 - 1 seconds
Response Format
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Version = 0 | Opcode (129/130) | Result Code |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Seconds Since Start of Epoch |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Internal Port | Mapped External Port |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Port Mapping Lifetime (seconds) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Total: 16 bytes
Opcode:
- 129 = UDP port mapping response
- 130 = TCP port mapping response
Mapped External Port:
- Actual external port assigned
- May differ from suggested port
- 0 = mapping failed or deleted
Port Mapping Lifetime:
- Actual lifetime granted
- May be less than requested
- Gateway may reduce based on policy
Gateway Discovery
How to Find NAT Gateway
Method 1: Default Gateway (Recommended)
- Use system's default gateway
- Most common case
- Works in 99% of deployments
import socket
import struct
def get_default_gateway():
"""Get default gateway IP (Linux)."""
with open('/proc/net/route') as f:
for line in f:
fields = line.strip().split()
if fields[1] == '00000000': # Default route
gateway_hex = fields[2]
# Convert hex to IP
gateway_int = int(gateway_hex, 16)
return socket.inet_ntoa(struct.pack('<I', gateway_int))
return None
# Or use netifaces library
import netifaces
gws = netifaces.gateways()
gateway = gws['default'][netifaces.AF_INET][0]
Method 2: DHCP Option
- No DHCP option is defined for NAT-PMP
- RFC 6886 specifies using the default gateway instead
Method 3: Multicast Announcements
- Gateway sends unsolicited responses to 224.0.0.1:5350
when its external address changes
- Useful for detecting address changes, not for discovery
Best Practice:
Always try default gateway first
Client Implementation
Python Example
import socket
import struct
import time
class NATPMPClient:
def __init__(self, gateway_ip):
self.gateway_ip = gateway_ip
self.gateway_port = 5351
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.sock.settimeout(3.0)
def get_external_ip(self):
"""
Get NAT gateway's external IP address.
Returns:
(external_ip, epoch_seconds)
"""
# Build request
request = struct.pack('!BB', 0, 0) # Version 0, Opcode 0
# Send request
self.sock.sendto(request, (self.gateway_ip, self.gateway_port))
try:
# Receive response
response, addr = self.sock.recvfrom(1024)
# Parse response
if len(response) < 12:
raise Exception("Invalid response length")
version, opcode, result_code, epoch, ext_ip = \
struct.unpack('!BBHII', response[:12])
if result_code != 0:
raise Exception(f"Error: result code {result_code}")
# Convert IP to string
external_ip = socket.inet_ntoa(struct.pack('!I', ext_ip))
return (external_ip, epoch)
except socket.timeout:
raise Exception("Request timeout - NAT-PMP not supported?")
def add_port_mapping(self, internal_port, external_port=0,
protocol='tcp', lifetime=3600):
"""
Add a port mapping.
Args:
internal_port: Port on local machine
external_port: Desired external port (0 = any)
protocol: 'tcp' or 'udp'
lifetime: Mapping duration in seconds (0 = delete)
Returns:
(mapped_external_port, actual_lifetime, epoch)
"""
# Build request
opcode = 1 if protocol == 'udp' else 2
request = struct.pack(
'!BBHHHI',
0, # Version
opcode, # 1=UDP, 2=TCP
0, # Reserved
internal_port,
external_port,
lifetime
)
# Send request
self.sock.sendto(request, (self.gateway_ip, self.gateway_port))
try:
# Receive response
response, addr = self.sock.recvfrom(1024)
# Parse response
if len(response) < 16:
raise Exception("Invalid response length")
version, resp_opcode, result_code, epoch, \
int_port, ext_port, actual_lifetime = \
struct.unpack('!BBHIHHI', response[:16])
if result_code != 0:
raise Exception(f"Error: result code {result_code}")
return (ext_port, actual_lifetime, epoch)
except socket.timeout:
raise Exception("Request timeout")
def delete_port_mapping(self, internal_port, protocol='tcp'):
"""Delete a port mapping by setting lifetime to 0."""
return self.add_port_mapping(
internal_port,
external_port=0,
protocol=protocol,
lifetime=0
)
def close(self):
self.sock.close()
# Usage example
if __name__ == '__main__':
# Get gateway from system
import netifaces
gws = netifaces.gateways()
gateway = gws['default'][netifaces.AF_INET][0]
print(f"Using gateway: {gateway}")
client = NATPMPClient(gateway)
try:
# Get external IP
external_ip, epoch = client.get_external_ip()
print(f"External IP: {external_ip}")
print(f"Gateway uptime: {epoch} seconds")
# Add port mapping
print("\nCreating port mapping...")
external_port, lifetime, epoch = client.add_port_mapping(
internal_port=8080,
external_port=8080, # Prefer 8080
protocol='tcp',
lifetime=3600 # 1 hour
)
print(f"Mapping created:")
print(f" Internal: localhost:8080")
print(f" External: {external_ip}:{external_port}")
print(f" Lifetime: {lifetime} seconds")
# Keep mapping alive
print("\nMapping active. Press Ctrl+C to delete...")
try:
last_epoch = epoch
while True:
time.sleep(1800) # Renew every 30 minutes
# Renew mapping
external_port, lifetime, epoch = client.add_port_mapping(
internal_port=8080,
protocol='tcp',
lifetime=3600
)
# Check for gateway reboot
if epoch < last_epoch:
print("Warning: Gateway rebooted! Mapping recreated.")
last_epoch = epoch
print(f"Mapping renewed: {lifetime}s remaining")
except KeyboardInterrupt:
pass
# Delete mapping
print("\nDeleting mapping...")
client.delete_port_mapping(8080, 'tcp')
print("Mapping deleted")
except Exception as e:
print(f"Error: {e}")
finally:
client.close()
JavaScript/Node.js Example
const dgram = require('dgram');
class NATPMPClient {
constructor(gatewayIP) {
this.gatewayIP = gatewayIP;
this.gatewayPort = 5351;
this.socket = dgram.createSocket('udp4');
}
getExternalIP() {
return new Promise((resolve, reject) => {
// Build request
const request = Buffer.alloc(2);
request.writeUInt8(0, 0); // Version
request.writeUInt8(0, 1); // Opcode
const timeout = setTimeout(() => {
reject(new Error('Request timeout'));
}, 3000);
this.socket.once('message', (response) => {
clearTimeout(timeout);
try {
const version = response.readUInt8(0);
const opcode = response.readUInt8(1);
const resultCode = response.readUInt16BE(2);
if (resultCode !== 0) {
throw new Error(`Error: result code ${resultCode}`);
}
const epoch = response.readUInt32BE(4);
const ipBytes = [
response.readUInt8(8),
response.readUInt8(9),
response.readUInt8(10),
response.readUInt8(11)
];
const externalIP = ipBytes.join('.');
resolve({ externalIP, epoch });
} catch (error) {
reject(error);
}
});
this.socket.send(request, this.gatewayPort, this.gatewayIP);
});
}
addPortMapping(internalPort, externalPort = 0, protocol = 'tcp', lifetime = 3600) {
return new Promise((resolve, reject) => {
// Build request
const request = Buffer.alloc(12);
request.writeUInt8(0, 0); // Version
request.writeUInt8(protocol === 'udp' ? 1 : 2, 1); // Opcode
request.writeUInt16BE(0, 2); // Reserved
request.writeUInt16BE(internalPort, 4);
request.writeUInt16BE(externalPort, 6);
request.writeUInt32BE(lifetime, 8);
const timeout = setTimeout(() => {
reject(new Error('Request timeout'));
}, 3000);
this.socket.once('message', (response) => {
clearTimeout(timeout);
try {
const resultCode = response.readUInt16BE(2);
if (resultCode !== 0) {
throw new Error(`Error: result code ${resultCode}`);
}
const epoch = response.readUInt32BE(4);
const mappedPort = response.readUInt16BE(10);
const actualLifetime = response.readUInt32BE(12);
resolve({
externalPort: mappedPort,
lifetime: actualLifetime,
epoch
});
} catch (error) {
reject(error);
}
});
this.socket.send(request, this.gatewayPort, this.gatewayIP);
});
}
deletePortMapping(internalPort, protocol = 'tcp') {
return this.addPortMapping(internalPort, 0, protocol, 0);
}
close() {
this.socket.close();
}
}
// Usage
const os = require('os');
function getDefaultGateway() {
// Simple gateway detection (platform-specific)
const interfaces = os.networkInterfaces();
// This is simplified - use proper gateway detection in production
return '192.168.1.1';
}
const gateway = getDefaultGateway();
const client = new NATPMPClient(gateway);
async function main() {
try {
// Get external IP
const { externalIP, epoch } = await client.getExternalIP();
console.log(`External IP: ${externalIP}`);
console.log(`Gateway uptime: ${epoch}s`);
// Add port mapping
const mapping = await client.addPortMapping(8080, 8080, 'tcp', 3600);
console.log('Mapping created:');
console.log(` External: ${externalIP}:${mapping.externalPort}`);
console.log(` Lifetime: ${mapping.lifetime}s`);
// Renew periodically
setInterval(async () => {
const renewed = await client.addPortMapping(8080, 8080, 'tcp', 3600);
console.log(`Mapping renewed: ${renewed.lifetime}s`);
}, 30 * 60 * 1000); // Every 30 minutes
} catch (error) {
console.error('Error:', error.message);
}
}
main();
Mapping Lifetime Management
Recommended Practices
1. Initial Lifetime
- Request 3600 seconds (1 hour)
- Gateway may grant less
- Never request > 1 day
2. Renewal Strategy
- Renew at 50% of lifetime
- If lifetime is 3600s, renew at 1800s
- Provides safety margin
3. Exponential Backoff
- If renewal fails, retry with backoff
- 1s, 2s, 4s, 8s, 16s, 32s
- Eventually recreate mapping
4. Epoch Monitoring
- Check epoch in each response
- If epoch < last_epoch: gateway rebooted
- Recreate all mappings
5. Cleanup
- Always delete mappings when done
- Set lifetime=0 to delete
- Graceful shutdown
Example: Lifetime Management
class MappingManager:
def __init__(self, client, internal_port, protocol='tcp'):
self.client = client
self.internal_port = internal_port
self.protocol = protocol
self.external_port = None
self.lifetime = None
self.last_epoch = None
self.running = False
def start(self):
"""Create and maintain mapping."""
self.running = True
# Create initial mapping
self._create_mapping()
# Renewal loop
while self.running:
# Sleep for half of lifetime
sleep_time = self.lifetime / 2
time.sleep(sleep_time)
if not self.running:
break
try:
# Renew mapping
ext_port, lifetime, epoch = self.client.add_port_mapping(
self.internal_port,
self.external_port, # Request same port
self.protocol,
3600
)
# Check for gateway reboot
if epoch < self.last_epoch:
print("Gateway rebooted - mapping recreated")
self.external_port = ext_port
self.lifetime = lifetime
self.last_epoch = epoch
print(f"Mapping renewed: {lifetime}s")
except Exception as e:
print(f"Renewal failed: {e}")
# Retry with backoff
self._retry_with_backoff()
def _create_mapping(self):
"""Create initial mapping."""
ext_port, lifetime, epoch = self.client.add_port_mapping(
self.internal_port,
0, # Any port
self.protocol,
3600
)
self.external_port = ext_port
self.lifetime = lifetime
self.last_epoch = epoch
print(f"Mapping created: :{ext_port} -> localhost:{self.internal_port}")
def _retry_with_backoff(self):
"""Retry with exponential backoff."""
delays = [1, 2, 4, 8, 16, 32]
for delay in delays:
time.sleep(delay)
try:
self._create_mapping()
return
except Exception as e:
print(f"Retry failed: {e}")
print("All retries failed")
self.running = False
def stop(self):
"""Stop and delete mapping."""
self.running = False
try:
self.client.delete_port_mapping(self.internal_port, self.protocol)
print("Mapping deleted")
except Exception as e:
print(f"Failed to delete mapping: {e}")
# Usage
client = NATPMPClient(gateway)
manager = MappingManager(client, 8080, 'tcp')
# Start in background thread
import threading
thread = threading.Thread(target=manager.start)
thread.start()
# Application runs...
# Cleanup on exit
manager.stop()
thread.join()
client.close()
Security Considerations
Threats:
1. Unauthorized Mappings
- Malware can open ports
- No authentication in protocol
- Mitigation: Monitor gateway logs
2. Resource Exhaustion
- Many mappings consume gateway resources
- DoS via mapping requests
- Mitigation: Gateway enforces limits
3. Information Disclosure
- External IP revealed
- Internal topology visible
- Mitigation: Minimal, inherent to NAT
4. Spoofing
- Off-path attacker sends fake responses
- Mitigation: Check source IP/port
Best Practices:
1. Only request needed mappings
2. Use shortest lifetime necessary
3. Delete mappings when done
4. Monitor for unexpected mappings
5. Validate response source
6. Handle errors gracefully
Troubleshooting
# Test if gateway supports NAT-PMP
nc -u 192.168.1.1 5351
# Send external IP request (two zero bytes)
printf '\x00\x00' | nc -u 192.168.1.1 5351
# Expected response (hex):
# 00 80 00 00 SSSS SSSS EE EE EE EE
# 00: Version
# 80: Opcode (128 = external IP response)
# 00 00: Result (success)
# SSSS SSSS: Epoch seconds
# EE EE EE EE: External IP
# tcpdump NAT-PMP traffic
sudo tcpdump -i any -n udp port 5351 -X
# Check if gateway has NAT-PMP enabled
# Router admin interface → Port Forwarding → NAT-PMP
# Common issues:
# - Gateway doesn't support NAT-PMP
# - NAT-PMP disabled in gateway
# - Firewall blocks UDP 5351
# - Wrong gateway address
# - Gateway behind another NAT
# Test with a real client (natpmpc, shipped with libnatpmp)
sudo apt install natpmpc   # Debian/Ubuntu
natpmpc -g 192.168.1.1 -a 8080 8080 tcp 3600
Comparison with Other Protocols
NAT-PMP vs PCP
NAT-PMP:
+ Simple, easy to implement
+ Low overhead (12-16 bytes)
+ Widely supported (Apple devices)
+ Battle-tested (since 2005)
- IPv4 only
- Single NAT only
- Limited features
PCP:
+ IPv4 and IPv6
+ Multiple NATs
+ More features (PEER, filters)
+ Modern design
- More complex
- Less deployed
- Larger packets
Migration Path:
- PCP designed as NAT-PMP successor
- PCP port (5351) intentionally same
- Clients can try both
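Because a NAT-PMP-only gateway answers a PCP probe with version 0 and result code 1 (Unsupported Version), a client can decide which protocol to speak from the gateway's first reply. A rough sketch of that decision, operating on the raw response bytes:

```python
def choose_protocol(first_response: bytes) -> str:
    """Decide which protocol the gateway speaks from its first reply.

    A PCP-aware gateway echoes version 2; a NAT-PMP-only gateway
    replies with version 0 (and result code 1 in bytes 2-3) when it
    receives a PCP request it does not understand.
    """
    if not first_response:
        return 'none'        # no reply: neither protocol available
    version = first_response[0]
    if version >= 2:
        return 'pcp'
    return 'nat-pmp' if version == 0 else 'unknown'

# PCP MAP response (version 2, opcode 1 with response bit set)
assert choose_protocol(b'\x02\x81' + bytes(22)) == 'pcp'
# NAT-PMP "Unsupported Version" reply to a PCP probe
assert choose_protocol(b'\x00\x80\x00\x01' + bytes(8)) == 'nat-pmp'
```

A client using this check can send a single PCP request and transparently downgrade to NAT-PMP framing on the same socket.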
Feature Comparison
| Feature             | NAT-PMP | PCP    | UPnP-IGD |
|---------------------|---------|--------|----------|
| Packet Size         | 12-16B  | 24+B   | KB+      |
| Round Trips         | 1       | 1      | Multiple |
| IPv6                | No      | Yes    | Partial  |
| Lifetime Management | Yes     | Yes    | No       |
| Third-party Mapping | No      | Yes    | No       |
| Firewall Control    | No      | Yes    | No       |
| Authentication      | No      | No     | No       |
| Complexity          | Low     | Medium | High     |
| Apple Support       | Native  | Native | Emulated |
| Linux Support       | Good    | Good   | Good     |
Common Use Cases
1. BitTorrent Client
# BitTorrent client
client = NATPMPClient(gateway)

# Map port for incoming connections
port = 6881
ext_port, lifetime, _ = client.add_port_mapping(
    internal_port=port,
    external_port=port,
    protocol='tcp',
    lifetime=7200  # 2 hours
)
print(f"Listening on port {ext_port}")

# Announce to tracker with external port
announce_to_tracker(ext_port)

# Maintain mapping while downloading (renew at half the lifetime)
while downloading:
    time.sleep(3600)
    client.add_port_mapping(port, port, 'tcp', 7200)

# Cleanup
client.delete_port_mapping(port, 'tcp')
2. VoIP Application
# VoIP client
client = NATPMPClient(gateway)

# Map SIP and RTP ports
sip_port = 5060
rtp_port = 16384

# SIP (TCP)
sip_ext, _, _ = client.add_port_mapping(
    sip_port, sip_port, 'tcp', 3600
)

# RTP (UDP)
rtp_ext, _, _ = client.add_port_mapping(
    rtp_port, rtp_port, 'udp', 3600
)

# Register with external address
external_ip, _ = client.get_external_ip()
register_with_server(external_ip, sip_ext, rtp_ext)
3. Game Server
# Game server
client = NATPMPClient(gateway)

# Map game port
game_port = 27015
ext_port, lifetime, _ = client.add_port_mapping(
    game_port, game_port, 'udp', 7200
)
external_ip, _ = client.get_external_ip()

# Advertise server
advertise_server(f"{external_ip}:{ext_port}")
print(f"Server accessible at {external_ip}:{ext_port}")
ELI10: NAT-PMP Explained Simply
NAT-PMP is like asking your house to automatically open a window:
Without NAT-PMP
You: Want friend to visit
Problem: Door is locked
Solution: Ask parent to unlock door manually
Issue: Parent must remember, manual work
With NAT-PMP
You: "Please open door for 1 hour"
Smart House: Opens door automatically
Friend: Can enter for 1 hour
Smart House: Locks door after 1 hour
Automatic + Safe!
In Computer Terms
Your App: "Need port 8080 open for 1 hour"
Router: Opens port 8080 automatically
Internet: Can now reach your app on port 8080
Router: Closes port after 1 hour
Benefits:
- No manual configuration
- Automatic cleanup
- Time-limited (secure)
- Application controls it
Further Resources
Specifications
- RFC 6886 - NAT-PMP
- RFC 6887 - PCP (Successor)
Implementations
Tools
- natpmpc - Command-line client
- NAT Port Mapping Protocol
Apple Documentation
- NAT-PMP on macOS
- Bonjour implementation includes NAT-PMP
UPnP (Universal Plug and Play)
Overview
UPnP is a set of networking protocols that enables devices on a network to seamlessly discover each other and establish functional network services for data sharing, communications, and entertainment. It allows devices to automatically configure themselves and announce their presence to other devices.
UPnP Components
1. Discovery (SSDP)
- Find devices on network
- Announce presence
2. Description
- Device capabilities
- Services offered
3. Control
- Invoke actions
- Query state
4. Eventing
- Subscribe to state changes
- Receive notifications
5. Presentation
- Web-based UI
- Human interaction
UPnP Architecture
Control Point (Client) Device (Server)
| |
| 1. Discovery (SSDP) |
|<-------------------------->|
| |
| 2. Description (XML) |
|--------------------------->|
|<---------------------------|
| |
| 3. Control (SOAP) |
|--------------------------->|
|<---------------------------|
| |
| 4. Eventing (GENA) |
|--------------------------->|
| (Subscribe) |
|<---------------------------|
| (Events) |
SSDP (Simple Service Discovery Protocol)
Discovery Process
Device Announcement:
Device joins network:
NOTIFY * HTTP/1.1
Host: 239.255.255.250:1900
Cache-Control: max-age=1800
Location: http://192.168.1.100:8080/description.xml
NT: upnp:rootdevice
NTS: ssdp:alive
Server: Linux/5.4 UPnP/1.0 MyDevice/1.0
USN: uuid:12345678-1234-1234-1234-123456789abc::upnp:rootdevice
Sent to multicast address 239.255.255.250:1900
Announces device presence
Device Search (M-SEARCH):
Control point searches for devices:
M-SEARCH * HTTP/1.1
Host: 239.255.255.250:1900
MAN: "ssdp:discover"
ST: ssdp:all
MX: 3
(Search for all devices, wait up to 3 seconds)
Multicast to 239.255.255.250:1900
Device Response:
HTTP/1.1 200 OK
Cache-Control: max-age=1800
Location: http://192.168.1.100:8080/description.xml
Server: Linux/5.4 UPnP/1.0 MyDevice/1.0
ST: upnp:rootdevice
USN: uuid:12345678-1234-1234-1234-123456789abc::upnp:rootdevice
Unicast response back to control point
SSDP Multicast
IPv4 Address: 239.255.255.250
Port: 1900 (UDP)
All UPnP devices listen on this address
Used for discovery announcements
Search Targets (ST)
ssdp:all - All devices and services
upnp:rootdevice - Root devices only
uuid:<device-uuid> - Specific device
urn:schemas-upnp-org:device:<deviceType>:<version>
urn:schemas-upnp-org:service:<serviceType>:<version>
Examples:
ST: urn:schemas-upnp-org:device:MediaRenderer:1
ST: urn:schemas-upnp-org:service:ContentDirectory:1
Device Description
Description XML
<?xml version="1.0"?>
<root xmlns="urn:schemas-upnp-org:device-1-0">
<specVersion>
<major>1</major>
<minor>0</minor>
</specVersion>
<device>
<deviceType>urn:schemas-upnp-org:device:MediaRenderer:1</deviceType>
<friendlyName>Living Room TV</friendlyName>
<manufacturer>Samsung</manufacturer>
<manufacturerURL>http://www.samsung.com</manufacturerURL>
<modelDescription>Smart TV</modelDescription>
<modelName>UN55TU8000</modelName>
<modelNumber>8000</modelNumber>
<serialNumber>123456789</serialNumber>
<UDN>uuid:12345678-1234-1234-1234-123456789abc</UDN>
<presentationURL>http://192.168.1.100:8080/</presentationURL>
<serviceList>
<service>
<serviceType>urn:schemas-upnp-org:service:AVTransport:1</serviceType>
<serviceId>urn:upnp-org:serviceId:AVTransport</serviceId>
<SCPDURL>/service/AVTransport/scpd.xml</SCPDURL>
<controlURL>/service/AVTransport/control</controlURL>
<eventSubURL>/service/AVTransport/event</eventSubURL>
</service>
</serviceList>
</device>
</root>
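A control point parses this description to find each service's control URL before it can invoke actions. A short Python sketch using only the standard library (the sample XML below is a trimmed copy of the description above; `list_services` is an illustrative helper name):

```python
import xml.etree.ElementTree as ET

NS = {'d': 'urn:schemas-upnp-org:device-1-0'}

def list_services(description_xml):
    """Return (serviceType, controlURL) pairs from a device description."""
    root = ET.fromstring(description_xml)
    return [
        (svc.findtext('d:serviceType', namespaces=NS),
         svc.findtext('d:controlURL', namespaces=NS))
        for svc in root.findall('.//d:service', NS)
    ]

sample = '''<?xml version="1.0"?>
<root xmlns="urn:schemas-upnp-org:device-1-0">
  <device>
    <serviceList>
      <service>
        <serviceType>urn:schemas-upnp-org:service:AVTransport:1</serviceType>
        <controlURL>/service/AVTransport/control</controlURL>
      </service>
    </serviceList>
  </device>
</root>'''
print(list_services(sample))
# [('urn:schemas-upnp-org:service:AVTransport:1', '/service/AVTransport/control')]
```

Note the elements live in the `urn:schemas-upnp-org:device-1-0` namespace, so the queries must be namespace-qualified.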
Service Description (SCPD)
<?xml version="1.0"?>
<scpd xmlns="urn:schemas-upnp-org:service-1-0">
<specVersion>
<major>1</major>
<minor>0</minor>
</specVersion>
<actionList>
<action>
<name>Play</name>
<argumentList>
<argument>
<name>Speed</name>
<direction>in</direction>
<relatedStateVariable>TransportPlaySpeed</relatedStateVariable>
</argument>
</argumentList>
</action>
<action>
<name>Stop</name>
</action>
</actionList>
<serviceStateTable>
<stateVariable sendEvents="yes">
<name>TransportState</name>
<dataType>string</dataType>
<allowedValueList>
<allowedValue>PLAYING</allowedValue>
<allowedValue>STOPPED</allowedValue>
<allowedValue>PAUSED_PLAYBACK</allowedValue>
</allowedValueList>
</stateVariable>
<stateVariable sendEvents="no">
<name>TransportPlaySpeed</name>
<dataType>string</dataType>
<defaultValue>1</defaultValue>
</stateVariable>
</serviceStateTable>
</scpd>
UPnP Control (SOAP)
Action Invocation
Request:
POST /service/AVTransport/control HTTP/1.1
Host: 192.168.1.100:8080
Content-Type: text/xml; charset="utf-8"
SOAPAction: "urn:schemas-upnp-org:service:AVTransport:1#Play"
Content-Length: 299
<?xml version="1.0"?>
<s:Envelope
xmlns:s="http://schemas.xmlsoap.org/soap/envelope/"
s:encodingStyle="http://schemas.xmlsoap.org/soap/encoding/">
<s:Body>
<u:Play xmlns:u="urn:schemas-upnp-org:service:AVTransport:1">
<InstanceID>0</InstanceID>
<Speed>1</Speed>
</u:Play>
</s:Body>
</s:Envelope>
Response:
HTTP/1.1 200 OK
Content-Type: text/xml; charset="utf-8"
Content-Length: 250
<?xml version="1.0"?>
<s:Envelope
xmlns:s="http://schemas.xmlsoap.org/soap/envelope/"
s:encodingStyle="http://schemas.xmlsoap.org/soap/encoding/">
<s:Body>
<u:PlayResponse xmlns:u="urn:schemas-upnp-org:service:AVTransport:1">
</u:PlayResponse>
</s:Body>
</s:Envelope>
Error Response
HTTP/1.1 500 Internal Server Error
Content-Type: text/xml; charset="utf-8"
<?xml version="1.0"?>
<s:Envelope
xmlns:s="http://schemas.xmlsoap.org/soap/envelope/"
s:encodingStyle="http://schemas.xmlsoap.org/soap/encoding/">
<s:Body>
<s:Fault>
<faultcode>s:Client</faultcode>
<faultstring>UPnPError</faultstring>
<detail>
<UPnPError xmlns="urn:schemas-upnp-org:control-1-0">
<errorCode>701</errorCode>
<errorDescription>Transition not available</errorDescription>
</UPnPError>
</detail>
</s:Fault>
</s:Body>
</s:Envelope>
UPnP Eventing (GENA)
Subscribe to Events
Request:
SUBSCRIBE /service/AVTransport/event HTTP/1.1
Host: 192.168.1.100:8080
Callback: <http://192.168.1.50:8888/notify>
NT: upnp:event
Timeout: Second-1800
Response:
HTTP/1.1 200 OK
SID: uuid:subscription-12345
Timeout: Second-1800
Initial Event (State Snapshot)
NOTIFY /notify HTTP/1.1
Host: 192.168.1.50:8888
Content-Type: text/xml
NT: upnp:event
NTS: upnp:propchange
SID: uuid:subscription-12345
SEQ: 0
<?xml version="1.0"?>
<e:propertyset xmlns:e="urn:schemas-upnp-org:event-1-0">
<e:property>
<TransportState>STOPPED</TransportState>
</e:property>
<e:property>
<CurrentTrack>1</CurrentTrack>
</e:property>
</e:propertyset>
Subsequent Events
NOTIFY /notify HTTP/1.1
Host: 192.168.1.50:8888
Content-Type: text/xml
NT: upnp:event
NTS: upnp:propchange
SID: uuid:subscription-12345
SEQ: 1
<?xml version="1.0"?>
<e:propertyset xmlns:e="urn:schemas-upnp-org:event-1-0">
<e:property>
<TransportState>PLAYING</TransportState>
</e:property>
</e:propertyset>
Unsubscribe
UNSUBSCRIBE /service/AVTransport/event HTTP/1.1
Host: 192.168.1.100:8080
SID: uuid:subscription-12345
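On the callback side, a subscriber has to pull the changed state variables out of each NOTIFY body. A minimal Python sketch of the parsing step (the `<e:property>` wrappers are in the event namespace, but the state variables inside them are not; `parse_propertyset` is an illustrative helper name):

```python
import xml.etree.ElementTree as ET

EVENT_NS = '{urn:schemas-upnp-org:event-1-0}'

def parse_propertyset(xml_text):
    """Map changed state-variable names to their new values."""
    root = ET.fromstring(xml_text)
    changes = {}
    for prop in root.findall(f'{EVENT_NS}property'):
        for var in prop:  # each <e:property> wraps one state variable
            changes[var.tag] = var.text
    return changes

event = '''<?xml version="1.0"?>
<e:propertyset xmlns:e="urn:schemas-upnp-org:event-1-0">
  <e:property><TransportState>PLAYING</TransportState></e:property>
  <e:property><CurrentTrack>2</CurrentTrack></e:property>
</e:propertyset>'''
print(parse_propertyset(event))
# {'TransportState': 'PLAYING', 'CurrentTrack': '2'}
```

The SEQ header in the NOTIFY lets the subscriber detect missed events; a gap means the local state snapshot may be stale.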
UPnP IGD (Internet Gateway Device)
Port Mapping
One of the most common UPnP uses:
Add Port Mapping Request:
POST /control/WANIPConnection HTTP/1.1
Host: 192.168.1.1:5000
Content-Type: text/xml; charset="utf-8"
SOAPAction: "urn:schemas-upnp-org:service:WANIPConnection:1#AddPortMapping"
<?xml version="1.0"?>
<s:Envelope ...>
<s:Body>
<u:AddPortMapping xmlns:u="urn:schemas-upnp-org:service:WANIPConnection:1">
<NewRemoteHost></NewRemoteHost>
<NewExternalPort>8080</NewExternalPort>
<NewProtocol>TCP</NewProtocol>
<NewInternalPort>8080</NewInternalPort>
<NewInternalClient>192.168.1.50</NewInternalClient>
<NewEnabled>1</NewEnabled>
<NewPortMappingDescription>My Web Server</NewPortMappingDescription>
<NewLeaseDuration>0</NewLeaseDuration>
</u:AddPortMapping>
</s:Body>
</s:Envelope>
Result:
External: <public-ip>:8080
↓
Internal: 192.168.1.50:8080
Automatic NAT traversal!
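A client builds this request by filling in the SOAP template with the WANIPConnection argument names shown above. A hedged Python sketch that only constructs the envelope; posting it to the gateway's discovered control URL works like the `requests.post` control example later in this section (`build_add_port_mapping` is an illustrative helper name):

```python
SOAP_TEMPLATE = '''<?xml version="1.0"?>
<s:Envelope
    xmlns:s="http://schemas.xmlsoap.org/soap/envelope/"
    s:encodingStyle="http://schemas.xmlsoap.org/soap/encoding/">
  <s:Body>
    <u:AddPortMapping xmlns:u="urn:schemas-upnp-org:service:WANIPConnection:1">
{args}
    </u:AddPortMapping>
  </s:Body>
</s:Envelope>'''

def build_add_port_mapping(internal_ip, internal_port, external_port,
                           protocol='TCP', description='', lease=0):
    """Render the AddPortMapping SOAP body (lease 0 = indefinite)."""
    fields = {
        'NewRemoteHost': '',
        'NewExternalPort': external_port,
        'NewProtocol': protocol,
        'NewInternalPort': internal_port,
        'NewInternalClient': internal_ip,
        'NewEnabled': 1,
        'NewPortMappingDescription': description,
        'NewLeaseDuration': lease,
    }
    args = '\n'.join(f'      <{k}>{v}</{k}>' for k, v in fields.items())
    return SOAP_TEMPLATE.format(args=args)

body = build_add_port_mapping('192.168.1.50', 8080, 8080, 'TCP', 'My Web Server')
```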
Get External IP
POST /control/WANIPConnection HTTP/1.1
SOAPAction: "urn:schemas-upnp-org:service:WANIPConnection:1#GetExternalIPAddress"
<u:GetExternalIPAddress xmlns:u="urn:schemas-upnp-org:service:WANIPConnection:1">
</u:GetExternalIPAddress>
Response:
<u:GetExternalIPAddressResponse>
<NewExternalIPAddress>203.0.113.5</NewExternalIPAddress>
</u:GetExternalIPAddressResponse>
Common UPnP Device Types
MediaServer - Content provider (NAS, PC)
MediaRenderer - Content consumer (TV, speaker)
InternetGatewayDevice - Router/NAT
WANConnectionDevice - WAN connection management
PrinterBasic - Network printer
Scanner - Network scanner
HVAC - Heating/cooling control
Lighting - Smart lights
SecurityDevice - Cameras, sensors
UPnP Client Implementation
Python Example (Discovery)
import socket

SSDP_ADDR = '239.255.255.250'
SSDP_PORT = 1900

def discover_devices():
    # M-SEARCH message
    message = '\r\n'.join([
        'M-SEARCH * HTTP/1.1',
        f'Host: {SSDP_ADDR}:{SSDP_PORT}',
        'MAN: "ssdp:discover"',
        'ST: ssdp:all',
        'MX: 3',
        '',
        ''
    ])

    # Create socket
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    sock.settimeout(5)

    # Send M-SEARCH
    sock.sendto(message.encode(), (SSDP_ADDR, SSDP_PORT))

    # Receive responses until the timeout fires
    devices = []
    try:
        while True:
            data, addr = sock.recvfrom(1024)
            response = data.decode()
            # Parse Location header (header-name casing varies by device)
            for line in response.split('\r\n'):
                if line.lower().startswith('location:'):
                    devices.append(line.split(':', 1)[1].strip())
                    break
    except socket.timeout:
        pass

    sock.close()
    return devices

# Usage
devices = discover_devices()
for device in devices:
    print(f"Found device: {device}")
Python Example (Control)
import requests

def control_device(control_url, service_type, action, args):
    # Build SOAP envelope
    soap_body = f'''<?xml version="1.0"?>
<s:Envelope
    xmlns:s="http://schemas.xmlsoap.org/soap/envelope/"
    s:encodingStyle="http://schemas.xmlsoap.org/soap/encoding/">
  <s:Body>
    <u:{action} xmlns:u="{service_type}">
      {''.join(f'<{k}>{v}</{k}>' for k, v in args.items())}
    </u:{action}>
  </s:Body>
</s:Envelope>'''

    headers = {
        'Content-Type': 'text/xml; charset="utf-8"',
        'SOAPAction': f'"{service_type}#{action}"'
    }

    response = requests.post(control_url, data=soap_body, headers=headers)
    return response.text

# Usage
control_url = 'http://192.168.1.100:8080/service/AVTransport/control'
service_type = 'urn:schemas-upnp-org:service:AVTransport:1'
action = 'Play'
args = {'InstanceID': '0', 'Speed': '1'}

result = control_device(control_url, service_type, action, args)
print(result)
UPnP Tools
Command Line Tools
# upnpc (miniupnpc)
# Install: apt-get install miniupnpc
# Discover IGD devices
upnpc -l
# Get external IP
upnpc -s
# Add port mapping
upnpc -a 192.168.1.50 8080 8080 TCP
# List port mappings
upnpc -L
# Delete port mapping
upnpc -d 8080 TCP
GUI Tools
- UPnP Inspector (Linux)
- UPnP Test Tool (Windows)
- Device Spy (UPnP Forum)
UPnP Security Issues
Major Vulnerabilities
1. No Authentication
Any device can control any other device
No password required
No encryption
Attack: Malicious app opens ports in router
2. Port Forwarding Abuse
Malware can:
- Open ports in router
- Expose internal services
- Create backdoors
Example:
Malware opens port 3389 (RDP)
Attacker can remotely access PC
3. SSDP Amplification DDoS
Attacker spoofs source IP as victim
Sends M-SEARCH to many UPnP devices
Devices respond to victim
Victim overwhelmed with traffic
Amplification factor: 30x-50x
4. XML External Entity (XXE)
Malicious device description:
<!DOCTYPE foo [
<!ENTITY xxe SYSTEM "file:///etc/passwd">
]>
<root>&xxe;</root>
Can read local files
Server-side request forgery
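A quick way to see why parser choice matters: Python's standard `xml.etree` does not fetch external entities, so the reference above comes back undefined and parsing fails, which is the safe behavior when handling untrusted device descriptions. Parsers that silently expand such entities are the vulnerable ones.

```python
import xml.etree.ElementTree as ET

# The XXE payload from the example above
malicious = '''<!DOCTYPE foo [
<!ENTITY xxe SYSTEM "file:///etc/passwd">
]>
<root>&xxe;</root>'''

try:
    ET.fromstring(malicious)
    print('parsed (parser would be vulnerable to XXE)')
except ET.ParseError as err:
    # expat does not resolve external entities, so &xxe; is undefined
    print('rejected:', err)
```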
Security Best Practices
1. Disable UPnP on router
- If not needed, turn it off
- Most secure option
2. Use UPnP-UP (UPnP with User Profile)
- Authentication layer
- Access control
3. Firewall rules
- Block SSDP multicast from WAN
- Limit UPnP to trusted VLANs
4. Whitelist devices
- Only allow known devices
- MAC address filtering
5. Monitor port mappings
- Regular audits
- Alert on unexpected changes
6. Update firmware
- Patch vulnerabilities
- Keep devices current
UPnP vs Alternatives
vs Manual Port Forwarding
UPnP:
Pros: Automatic, easy
Cons: Security risk, no control
Manual:
Pros: Secure, controlled
Cons: Technical knowledge required
vs NAT-PMP / PCP
NAT-PMP (Apple):
- Similar to UPnP
- Simpler protocol
- Better security
PCP (Port Control Protocol):
- Successor to NAT-PMP
- IETF standard
- IPv6 support
vs STUN/TURN
UPnP: Local network discovery and control
STUN/TURN: NAT traversal for P2P connections
Different use cases, can complement each other
ELI10
UPnP is like devices introducing themselves and asking for help:
Discovery (Meeting New Friends):
New TV joins network:
TV: "Hi everyone! I'm a TV and can play videos!"
All devices hear the announcement
Your phone: "Cool, I found a TV!"
Control (Asking for Favors):
Phone to TV: "Can you play this video?"
TV: "Sure! Playing now."
Gaming console to router: "Can you open port 3478?"
Router: "Done! Port is open."
Problems (Security Issues):
Bad actor: "Hey router, open all ports!"
Router: "OK!" (No questions asked)
→ This is dangerous!
Better approach:
Router: "Who are you? Do you have permission?"
Bad actor: "Uh... never mind."
When to Use:
- Home media streaming
- Gaming (automatic port opening)
- Smart home devices
- Printing
When to Disable:
- Public networks
- When security is critical
- Enterprise environments
- If you don't need it
Rule of Thumb:
- Home network: Convenient (but understand risks)
- Business network: Usually disable
- Gaming: Helpful for matchmaking
- Important: Monitor what ports get opened!
Further Resources
- UPnP Forum
- RFC 6970 - UPnP IGD-PCP Interworking
- UPnP Device Architecture
- miniupnpc Library
- Security Concerns
WebSocket
Overview
WebSocket is a communication protocol that provides full-duplex communication channels over a single TCP connection. It enables real-time, bidirectional communication between a client and server with low overhead, making it ideal for interactive web applications.
Key Characteristics
Protocol: ws:// (unencrypted) or wss:// (encrypted)
Port: 80 (ws) or 443 (wss)
Transport: TCP
Connection: Long-lived, persistent
Communication: Full-duplex (bidirectional)
Latency: Low (no HTTP overhead after handshake)
Overhead: 2-14 bytes per frame
Status: RFC 6455 (2011)
Benefits:
✓ Real-time bidirectional communication
✓ Low latency (no polling overhead)
✓ Efficient (minimal frame overhead)
✓ Server can push data to client
✓ Single TCP connection
✓ Works through proxies and firewalls
✓ Subprotocol support
WebSocket vs Alternatives
HTTP Polling
Traditional HTTP Request/Response:
Client Server
| |
|──── HTTP GET (new data?) ─────>|
| |
|<─── HTTP Response (no) ────────|
| |
[wait 1 second]
| |
|──── HTTP GET (new data?) ─────>|
| |
|<─── HTTP Response (yes!) ──────|
| |
Problems:
- High latency (constant polling)
- Wasted requests (most return nothing)
- Server load (many unnecessary requests)
- HTTP overhead on every request
Long Polling
HTTP Long Polling:
Client Server
| |
|──── HTTP GET (wait) ──────────>|
| | [server holds request]
| | [data arrives]
|<─── HTTP Response (data!) ─────|
| |
|──── HTTP GET (wait) ──────────>|
| |
Better, but:
- Still HTTP overhead
- Reconnect after each message
- Server must handle many pending connections
- Not truly bidirectional
Server-Sent Events (SSE)
Server-Sent Events:
Client Server
| |
|──── HTTP GET (subscribe) ─────>|
| |
|<═══ Event stream ══════════════| (one-way)
|<═══ data: message 1 ═══════════|
|<═══ data: message 2 ═══════════|
|<═══ data: message 3 ═══════════|
| |
Good for:
✓ Server → Client only
✓ Text-based data
✓ Auto-reconnect
✓ Simpler than WebSocket
Limited:
✗ One-way only (server to client)
✗ HTTP/1.1 connection limit (6 per domain)
✗ Text only (no binary)
WebSocket
WebSocket:
Client Server
| |
|──── HTTP Upgrade ─────────────>|
|<─── 101 Switching Protocols ───|
| |
|<══════ WebSocket Open ═════════>|
| |
|──── Message 1 ────────────────>|
|<─── Message 2 ─────────────────|
|──── Message 3 ────────────────>|
|──── Message 4 ────────────────>|
|<─── Message 5 ─────────────────|
| |
Best for:
✓ Bidirectional communication
✓ Real-time updates
✓ Low latency required
✓ High message frequency
✓ Binary data support
WebSocket Protocol
Connection Handshake
WebSocket starts with an HTTP upgrade request:
Client Request:
GET /chat HTTP/1.1
Host: example.com
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Sec-WebSocket-Version: 13
Origin: https://example.com
Server Response:
HTTP/1.1 101 Switching Protocols
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
Key Fields:
Upgrade: websocket
- Request protocol upgrade from HTTP to WebSocket
Connection: Upgrade
- Indicates connection upgrade needed
Sec-WebSocket-Key: <base64-encoded-random>
- 16-byte random value, base64 encoded
- Prevents caching proxies from confusing requests
Sec-WebSocket-Version: 13
- WebSocket protocol version (13 is current)
Sec-WebSocket-Accept: <computed-hash>
- Server proves it understands WebSocket
- Computed as: base64(SHA-1(Key + magic-string))
- Magic string: 258EAFA5-E914-47DA-95CA-C5AB0DC85B11
Origin: https://example.com
- Browser sends origin for CORS check
- Server can validate allowed origins
After handshake:
- HTTP connection becomes WebSocket connection
- Both sides can send messages anytime
- Connection stays open until explicitly closed
Handshake Validation
// Server-side validation (conceptual)
const crypto = require('crypto');

function computeAcceptKey(clientKey) {
  const MAGIC = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11';
  const hash = crypto
    .createHash('sha1')
    .update(clientKey + MAGIC)
    .digest('base64');
  return hash;
}

// Example:
const clientKey = 'dGhlIHNhbXBsZSBub25jZQ==';
const acceptKey = computeAcceptKey(clientKey);
console.log(acceptKey);
// Output: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
Frame Format
After handshake, data is sent in frames:
WebSocket Frame Structure:
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-------+-+-------------+-------------------------------+
|F|R|R|R| opcode|M| Payload len | Extended payload length |
|I|S|S|S| (4) |A| (7) | (16/64) |
|N|V|V|V| |S| | (if payload len==126/127) |
| |1|2|3| |K| | |
+-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
| Extended payload length continued, if payload len == 127 |
+ - - - - - - - - - - - - - - - +-------------------------------+
| | Masking-key, if MASK set to 1 |
+-------------------------------+-------------------------------+
| Masking-key (continued) | Payload Data |
+-------------------------------- - - - - - - - - - - - - - - - +
: Payload Data continued ... :
+ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
| Payload Data continued ... |
+---------------------------------------------------------------+
Fields:
FIN (1 bit):
- 1 = final fragment
- 0 = more fragments coming
RSV1, RSV2, RSV3 (3 bits):
- Reserved for extensions
- Must be 0 unless extension negotiated
Opcode (4 bits):
- 0x0 = Continuation frame
- 0x1 = Text frame (UTF-8)
- 0x2 = Binary frame
- 0x8 = Connection close
- 0x9 = Ping
- 0xA = Pong
MASK (1 bit):
- 1 = payload is masked (required for client → server)
- 0 = payload not masked (server → client)
Payload Length (7 bits, or 7+16, or 7+64):
- 0-125: actual length
- 126: next 16 bits contain length
- 127: next 64 bits contain length
Masking Key (32 bits):
- Present if MASK = 1
- Random 4-byte key
- Client must mask all frames to server
Payload Data:
- Actual message data
- If masked, XOR with masking key
Minimum Frame Size:
- 2 bytes (no masking, payload ≤ 125 bytes)
- 6 bytes (with masking, payload ≤ 125 bytes)
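The field layout above can be made concrete with a small encoder. A Python sketch covering only the simplest case (one unfragmented text frame, payload at most 125 bytes, client-side masking as the protocol requires; `encode_client_text_frame` is an illustrative helper name):

```python
import os

def encode_client_text_frame(message):
    """Masked client-to-server text frame: FIN=1, opcode 0x1."""
    payload = message.encode('utf-8')
    if len(payload) > 125:
        raise ValueError('this sketch only handles the 7-bit length form')
    mask = os.urandom(4)                         # random 32-bit masking key
    header = bytes([0x81, 0x80 | len(payload)])  # FIN|text, MASK|length
    masked = bytes(b ^ mask[i % 4] for i, b in enumerate(payload))
    return header + mask + masked

frame = encode_client_text_frame('Hi')
# 2-byte header + 4-byte mask + 2-byte payload = 8 bytes total
```

This also illustrates the minimum-size figures above: the same frame sent server-to-client would omit the mask and be only 4 bytes.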
Message Types
Text Frame (Opcode 0x1):
- UTF-8 encoded text
- Most common for JSON, strings
Binary Frame (Opcode 0x2):
- Raw binary data
- Images, files, protocol buffers
Ping Frame (Opcode 0x9):
- Sent by either side
- Keep connection alive
- Check if peer responsive
Pong Frame (Opcode 0xA):
- Response to ping
- Sent automatically
- Contains same data as ping
Close Frame (Opcode 0x8):
- Initiates connection close
- Contains optional close code and reason
- Peer responds with close frame
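Close frames carry the optional status code and reason inside their payload: a 2-byte big-endian code followed by a UTF-8 reason string. A minimal Python sketch of the unmasked server-to-client form (`encode_close_frame` is an illustrative helper name):

```python
import struct

def encode_close_frame(code=1000, reason=''):
    """Unmasked close frame: FIN=1, opcode 0x8; payload = code + reason."""
    payload = struct.pack('!H', code) + reason.encode('utf-8')
    if len(payload) > 125:
        raise ValueError('reason too long for the 7-bit length form')
    return bytes([0x88, len(payload)]) + payload

frame = encode_close_frame(1000, 'bye')
# 0x88 = FIN + close opcode; payload = 2-byte code + 3-byte reason
```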
Client-Side Implementation
JavaScript (Browser)
// Create WebSocket connection
const socket = new WebSocket('ws://localhost:8080');

// Alternative: secure WebSocket
// const socket = new WebSocket('wss://example.com/socket');

// Connection opened
socket.addEventListener('open', (event) => {
  console.log('Connected to server');

  // Send message
  socket.send('Hello Server!');

  // Send JSON
  socket.send(JSON.stringify({
    type: 'chat',
    message: 'Hello!',
    timestamp: Date.now()
  }));

  // Send binary data
  const buffer = new Uint8Array([1, 2, 3, 4]);
  socket.send(buffer);
});

// Receive message
socket.addEventListener('message', (event) => {
  console.log('Message from server:', event.data);

  // Handle text data
  if (typeof event.data === 'string') {
    try {
      const data = JSON.parse(event.data);
      handleMessage(data);
    } catch (e) {
      console.log('Text:', event.data);
    }
  }

  // Handle binary data
  if (event.data instanceof Blob) {
    event.data.arrayBuffer().then(buffer => {
      const view = new Uint8Array(buffer);
      console.log('Binary data:', view);
    });
  }

  // Or receive as ArrayBuffer
  // socket.binaryType = 'arraybuffer';
});

// Connection closed
socket.addEventListener('close', (event) => {
  console.log('Disconnected from server');
  console.log('Code:', event.code);
  console.log('Reason:', event.reason);
  console.log('Clean:', event.wasClean);
});

// Connection error
socket.addEventListener('error', (error) => {
  console.error('WebSocket error:', error);
});

// Send messages
function sendMessage(text) {
  if (socket.readyState === WebSocket.OPEN) {
    socket.send(text);
  } else {
    console.error('WebSocket not connected');
  }
}

// Close connection
function closeConnection() {
  socket.close(1000, 'User closed connection');
}

// WebSocket states
console.log('CONNECTING:', WebSocket.CONNECTING); // 0
console.log('OPEN:', WebSocket.OPEN);             // 1
console.log('CLOSING:', WebSocket.CLOSING);       // 2
console.log('CLOSED:', WebSocket.CLOSED);         // 3

// Check current state
console.log('Current state:', socket.readyState);
Advanced Client Features
class WebSocketClient {
constructor(url, options = {}) {
this.url = url;
this.options = {
reconnect: true,
reconnectInterval: 1000,
reconnectDecay: 1.5,
maxReconnectInterval: 30000,
maxReconnectAttempts: 10,
...options
};
this.ws = null;
this.reconnectAttempts = 0;
this.messageQueue = [];
this.handlers = new Map();
this.connect();
}
connect() {
this.ws = new WebSocket(this.url);
this.ws.onopen = () => {
console.log('Connected');
this.reconnectAttempts = 0;
// Send queued messages
while (this.messageQueue.length > 0) {
this.send(this.messageQueue.shift());
}
this.emit('connect');
};
this.ws.onmessage = (event) => {
try {
const data = JSON.parse(event.data);
this.emit(data.type || 'message', data);
} catch (e) {
this.emit('message', event.data);
}
};
this.ws.onclose = (event) => {
console.log('Disconnected:', event.code, event.reason);
this.emit('disconnect', event);
if (this.options.reconnect) {
this.reconnect();
}
};
this.ws.onerror = (error) => {
console.error('WebSocket error:', error);
this.emit('error', error);
};
}
reconnect() {
if (this.reconnectAttempts >= this.options.maxReconnectAttempts) {
console.error('Max reconnect attempts reached');
this.emit('reconnect_failed');
return;
}
this.reconnectAttempts++;
const delay = Math.min(
this.options.reconnectInterval *
Math.pow(this.options.reconnectDecay, this.reconnectAttempts - 1),
this.options.maxReconnectInterval
);
console.log(`Reconnecting in ${delay}ms (attempt ${this.reconnectAttempts})`);
setTimeout(() => {
this.emit('reconnecting', this.reconnectAttempts);
this.connect();
}, delay);
}
send(data) {
if (this.ws.readyState === WebSocket.OPEN) {
const message = typeof data === 'string'
? data
: JSON.stringify(data);
this.ws.send(message);
} else {
console.log('Queueing message (not connected)');
this.messageQueue.push(data);
}
}
on(event, handler) {
if (!this.handlers.has(event)) {
this.handlers.set(event, []);
}
this.handlers.get(event).push(handler);
}
emit(event, data) {
if (this.handlers.has(event)) {
this.handlers.get(event).forEach(handler => handler(data));
}
}
close() {
this.options.reconnect = false;
if (this.ws) {
this.ws.close(1000, 'Client closed');
}
}
}
// Usage
const client = new WebSocketClient('ws://localhost:8080', {
reconnect: true,
maxReconnectAttempts: 5
});
client.on('connect', () => {
console.log('Connected!');
client.send({ type: 'auth', token: 'abc123' });
});
client.on('message', (data) => {
console.log('Received:', data);
});
client.on('disconnect', () => {
console.log('Connection lost');
});
client.send({ type: 'chat', message: 'Hello' });
Server-Side Implementation
Node.js with 'ws' Library
const WebSocket = require('ws');
const http = require('http');
// Create HTTP server
const server = http.createServer((req, res) => {
res.writeHead(200);
res.end('WebSocket server running');
});
// Create WebSocket server
const wss = new WebSocket.Server({ server });
// Track connected clients
const clients = new Set();
// Connection handler
wss.on('connection', (ws, req) => {
console.log('Client connected from', req.socket.remoteAddress);
// Add to client set
clients.add(ws);
// Send welcome message
ws.send(JSON.stringify({
type: 'welcome',
message: 'Connected to server',
clients: clients.size
}));
// Broadcast new connection to all clients
broadcast({
type: 'user-joined',
clients: clients.size
}, ws);
// Message handler
ws.on('message', (data) => {
console.log('Received:', data.toString());
try {
const message = JSON.parse(data);
// Handle different message types
switch (message.type) {
case 'chat':
// Broadcast chat message
broadcast({
type: 'chat',
message: message.message,
timestamp: Date.now()
});
break;
case 'ping':
// Respond to ping
ws.send(JSON.stringify({
type: 'pong',
timestamp: Date.now()
}));
break;
default:
console.log('Unknown message type:', message.type);
}
} catch (e) {
console.error('Invalid JSON:', e);
}
});
// Pong handler (heartbeat)
ws.on('pong', () => {
ws.isAlive = true;
});
// Close handler
ws.on('close', (code, reason) => {
console.log('Client disconnected:', code, reason.toString());
clients.delete(ws);
// Notify others
broadcast({
type: 'user-left',
clients: clients.size
});
});
// Error handler
ws.on('error', (error) => {
console.error('WebSocket error:', error);
});
// Mark as alive for heartbeat
ws.isAlive = true;
});
// Broadcast to all clients
function broadcast(data, exclude = null) {
const message = JSON.stringify(data);
clients.forEach(client => {
if (client !== exclude && client.readyState === WebSocket.OPEN) {
client.send(message);
}
});
}
// Heartbeat (detect dead connections)
const heartbeatInterval = setInterval(() => {
clients.forEach(ws => {
if (!ws.isAlive) {
console.log('Terminating dead connection');
ws.terminate();
clients.delete(ws);
return;
}
ws.isAlive = false;
ws.ping();
});
}, 30000); // Every 30 seconds
// Cleanup on server close
wss.on('close', () => {
clearInterval(heartbeatInterval);
});
// Start server
const PORT = 8080;
server.listen(PORT, () => {
console.log(`WebSocket server listening on port ${PORT}`);
});
Advanced Server Features
const WebSocket = require('ws');
const http = require('http');
const url = require('url');
class WebSocketServer {
constructor(options = {}) {
this.options = {
port: 8080,
pingInterval: 30000,
maxClients: 1000,
...options
};
this.server = http.createServer();
this.wss = new WebSocket.Server({ server: this.server });
this.rooms = new Map(); // roomId -> Set of clients
this.clients = new Map(); // ws -> client info
this.setupHandlers();
this.startHeartbeat();
}
setupHandlers() {
this.wss.on('connection', (ws, req) => {
// Check max clients
if (this.clients.size >= this.options.maxClients) {
ws.close(1008, 'Server full');
return;
}
// Parse URL parameters
const params = url.parse(req.url, true).query;
// Create client info
const clientInfo = {
id: this.generateId(),
ip: req.socket.remoteAddress,
rooms: new Set(),
authenticated: false,
metadata: {}
};
this.clients.set(ws, clientInfo);
ws.isAlive = true;
console.log(`Client ${clientInfo.id} connected`);
// Send client ID
this.send(ws, {
type: 'connected',
clientId: clientInfo.id
});
// Message handler
ws.on('message', (data) => {
this.handleMessage(ws, data);
});
// Pong handler
ws.on('pong', () => {
ws.isAlive = true;
});
// Close handler
ws.on('close', () => {
this.handleDisconnect(ws);
});
// Error handler
ws.on('error', (error) => {
console.error('Error:', error);
});
});
}
handleMessage(ws, data) {
const client = this.clients.get(ws);
if (!client) return;
try {
const message = JSON.parse(data);
switch (message.type) {
case 'auth':
this.handleAuth(ws, message);
break;
case 'join-room':
this.joinRoom(ws, message.room);
break;
case 'leave-room':
this.leaveRoom(ws, message.room);
break;
case 'message':
this.handleRoomMessage(ws, message);
break;
default:
console.log('Unknown message type:', message.type);
}
} catch (e) {
console.error('Invalid message:', e);
this.send(ws, {
type: 'error',
message: 'Invalid message format'
});
}
}
handleAuth(ws, message) {
const client = this.clients.get(ws);
// Validate token (simplified)
if (message.token === 'valid-token') {
client.authenticated = true;
client.metadata.username = message.username;
this.send(ws, {
type: 'auth-success',
username: message.username
});
} else {
this.send(ws, {
type: 'auth-failed',
message: 'Invalid token'
});
ws.close(1008, 'Authentication failed');
}
}
joinRoom(ws, roomId) {
const client = this.clients.get(ws);
if (!client?.authenticated) return;
// Create room if doesn't exist
if (!this.rooms.has(roomId)) {
this.rooms.set(roomId, new Set());
}
// Add client to room
this.rooms.get(roomId).add(ws);
client.rooms.add(roomId);
console.log(`Client ${client.id} joined room ${roomId}`);
// Notify client
this.send(ws, {
type: 'joined-room',
room: roomId,
members: this.rooms.get(roomId).size
});
// Notify room members
this.broadcastToRoom(roomId, {
type: 'user-joined',
userId: client.id,
username: client.metadata.username,
members: this.rooms.get(roomId).size
}, ws);
}
leaveRoom(ws, roomId) {
const client = this.clients.get(ws);
if (!client) return;
if (this.rooms.has(roomId)) {
this.rooms.get(roomId).delete(ws);
client.rooms.delete(roomId);
// Notify others
this.broadcastToRoom(roomId, {
type: 'user-left',
userId: client.id,
username: client.metadata.username,
members: this.rooms.get(roomId).size
});
// Clean up empty rooms
if (this.rooms.get(roomId).size === 0) {
this.rooms.delete(roomId);
}
}
}
handleRoomMessage(ws, message) {
const client = this.clients.get(ws);
if (!client?.authenticated) return;
if (message.room && this.rooms.has(message.room)) {
this.broadcastToRoom(message.room, {
type: 'message',
userId: client.id,
username: client.metadata.username,
message: message.content,
timestamp: Date.now()
});
}
}
handleDisconnect(ws) {
const client = this.clients.get(ws);
if (!client) return;
console.log(`Client ${client.id} disconnected`);
// Remove from all rooms
client.rooms.forEach(roomId => {
this.leaveRoom(ws, roomId);
});
// Remove from clients
this.clients.delete(ws);
}
send(ws, data) {
if (ws.readyState === WebSocket.OPEN) {
ws.send(JSON.stringify(data));
}
}
broadcastToRoom(roomId, data, exclude = null) {
if (!this.rooms.has(roomId)) return;
const message = JSON.stringify(data);
this.rooms.get(roomId).forEach(client => {
if (client !== exclude && client.readyState === WebSocket.OPEN) {
client.send(message);
}
});
}
broadcastToAll(data, exclude = null) {
const message = JSON.stringify(data);
this.clients.forEach((clientInfo, ws) => {
if (ws !== exclude && ws.readyState === WebSocket.OPEN) {
ws.send(message);
}
});
}
startHeartbeat() {
this.heartbeatInterval = setInterval(() => {
this.clients.forEach((clientInfo, ws) => {
if (!ws.isAlive) {
console.log(`Terminating dead connection: ${clientInfo.id}`);
ws.terminate();
return;
}
ws.isAlive = false;
ws.ping();
});
}, this.options.pingInterval);
}
generateId() {
// Prefer a cryptographically strong ID; Math.random is not collision-safe
return globalThis.crypto?.randomUUID?.() ?? Math.random().toString(36).slice(2, 15);
}
start() {
this.server.listen(this.options.port, () => {
console.log(`WebSocket server listening on port ${this.options.port}`);
});
}
stop() {
clearInterval(this.heartbeatInterval);
this.wss.close();
this.server.close();
}
}
// Usage
const server = new WebSocketServer({
port: 8080,
pingInterval: 30000,
maxClients: 1000
});
server.start();
Use Cases
1. Chat Application
// Client
class ChatClient {
constructor(url) {
this.socket = new WebSocket(url);
this.setupHandlers();
}
setupHandlers() {
this.socket.onopen = () => {
console.log('Connected to chat');
this.authenticate();
};
this.socket.onmessage = (event) => {
const data = JSON.parse(event.data);
switch (data.type) {
case 'message':
this.displayMessage(data);
break;
case 'user-joined':
this.showNotification(`${data.username} joined`);
break;
case 'user-left':
this.showNotification(`${data.username} left`);
break;
}
};
}
authenticate() {
this.socket.send(JSON.stringify({
type: 'auth',
token: localStorage.getItem('token'),
username: localStorage.getItem('username')
}));
}
joinRoom(roomId) {
this.socket.send(JSON.stringify({
type: 'join-room',
room: roomId
}));
}
sendMessage(roomId, message) {
this.socket.send(JSON.stringify({
type: 'message',
room: roomId,
content: message
}));
}
displayMessage(data) {
// Use textContent (not innerHTML) so user input cannot inject HTML
const messageElement = document.createElement('div');
messageElement.className = 'message';
const username = document.createElement('span');
username.className = 'username';
username.textContent = `${data.username}:`;
const content = document.createElement('span');
content.className = 'content';
content.textContent = data.message;
const timestamp = document.createElement('span');
timestamp.className = 'timestamp';
timestamp.textContent = new Date(data.timestamp).toLocaleTimeString();
messageElement.append(username, content, timestamp);
document.getElementById('messages').appendChild(messageElement);
}
showNotification(text) {
console.log(text);
}
}
const chat = new ChatClient('ws://localhost:8080');
// In practice, wait for the socket to open (and auth to succeed) before joining
chat.joinRoom('general');
2. Real-Time Dashboard
// Server: Push updates to dashboard
function broadcastMetrics() {
const metrics = {
type: 'metrics',
cpu: getCpuUsage(),
memory: getMemoryUsage(),
activeUsers: clients.size,
requestsPerSecond: getRequestRate(),
timestamp: Date.now()
};
broadcastToAll(metrics);
}
setInterval(broadcastMetrics, 1000);
// Client: Display real-time metrics
socket.onmessage = (event) => {
const data = JSON.parse(event.data);
if (data.type === 'metrics') {
updateChart('cpu', data.cpu);
updateChart('memory', data.memory);
updateCounter('users', data.activeUsers);
updateCounter('rps', data.requestsPerSecond);
}
};
3. Live Notifications
// Server: Send notifications
function notifyUser(userId, notification) {
const client = getUserWebSocket(userId);
if (client && client.readyState === WebSocket.OPEN) {
client.send(JSON.stringify({
type: 'notification',
title: notification.title,
message: notification.message,
priority: notification.priority,
timestamp: Date.now()
}));
}
}
// Client: Display notifications
socket.onmessage = (event) => {
const data = JSON.parse(event.data);
if (data.type === 'notification') {
showNotification(data.title, data.message);
// Play sound for high priority
if (data.priority === 'high') {
playNotificationSound();
}
// Desktop notification
if (Notification.permission === 'granted') {
new Notification(data.title, {
body: data.message,
icon: '/icon.png'
});
}
}
};
4. Collaborative Editing
// Server: Broadcast document changes
wss.on('connection', (ws) => {
ws.on('message', (data) => {
const change = JSON.parse(data);
if (change.type === 'edit') {
// Apply change to document
applyChange(change.documentId, change.operation);
// Broadcast to others in same document
broadcastToDocument(change.documentId, {
type: 'edit',
operation: change.operation,
userId: ws.userId
}, ws);
}
});
});
// Client: Send and receive edits
let editor = document.getElementById('editor');
// debounce() is any standard debounce helper (e.g. lodash.debounce)
editor.addEventListener('input', debounce((e) => {
socket.send(JSON.stringify({
type: 'edit',
documentId: currentDocId,
operation: {
type: 'insert',
position: e.target.selectionStart,
text: e.data
}
}));
}, 100));
socket.onmessage = (event) => {
const data = JSON.parse(event.data);
if (data.type === 'edit' && data.userId !== myUserId) {
applyRemoteEdit(data.operation);
}
};
5. Gaming/Multiplayer
// Server: Game state synchronization
const gameState = {
players: new Map(),
entities: []
};
function updateGameState() {
broadcastToAll({
type: 'state',
players: Array.from(gameState.players.values()),
entities: gameState.entities,
timestamp: Date.now()
});
}
// 60 updates per second
setInterval(updateGameState, 1000 / 60);
// Client: Send player input
const input = {
keys: {},
mouse: { x: 0, y: 0 }
};
document.addEventListener('keydown', (e) => {
input.keys[e.key] = true;
socket.send(JSON.stringify({
type: 'input',
keys: input.keys,
timestamp: Date.now()
}));
});
// Receive game state
socket.onmessage = (event) => {
const data = JSON.parse(event.data);
if (data.type === 'state') {
renderGameState(data.players, data.entities);
}
};
Security
Authentication
// Server: Verify token on connection
wss.on('connection', (ws, req) => {
// Extract token from query string
const params = new URLSearchParams(req.url.split('?')[1]);
const token = params.get('token');
// Verify token
if (!verifyToken(token)) {
ws.close(1008, 'Invalid authentication');
return;
}
ws.userId = decodeToken(token).userId;
});
// Or: Authenticate after connection
ws.on('message', (data) => {
const message = JSON.parse(data);
if (message.type === 'auth') {
if (verifyToken(message.token)) {
ws.authenticated = true;
ws.userId = decodeToken(message.token).userId;
ws.send(JSON.stringify({ type: 'auth-success' }));
} else {
ws.close(1008, 'Authentication failed');
}
} else if (!ws.authenticated) {
ws.send(JSON.stringify({
type: 'error',
message: 'Not authenticated'
}));
}
});
Origin Validation
// Server: Validate origin
wss.on('connection', (ws, req) => {
const origin = req.headers.origin;
const allowedOrigins = [
'https://example.com',
'https://app.example.com'
];
if (!allowedOrigins.includes(origin)) {
console.log('Rejected connection from:', origin);
ws.close(1008, 'Origin not allowed');
return;
}
// Accept connection
});
Rate Limiting
// Server: Rate limit messages
const rateLimits = new Map(); // clientId -> { count, resetAt }
ws.on('message', (data) => {
const clientId = ws.userId || ws.ip;
if (!rateLimits.has(clientId)) {
rateLimits.set(clientId, { count: 0, resetAt: Date.now() + 60000 });
}
const limit = rateLimits.get(clientId);
// Reset if window expired
if (Date.now() > limit.resetAt) {
limit.count = 0;
limit.resetAt = Date.now() + 60000;
}
// Check limit (100 messages per minute)
if (limit.count >= 100) {
ws.send(JSON.stringify({
type: 'error',
message: 'Rate limit exceeded'
}));
return;
}
limit.count++;
// Process message
handleMessage(ws, data);
});
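The `rateLimits` map above grows without bound because entries for departed clients are never removed. A self-contained sketch of the same fixed-window scheme with periodic eviction (class name and defaults are illustrative):

```javascript
// Fixed-window rate limiter with eviction of expired entries.
class RateLimiter {
  constructor(limit = 100, windowMs = 60000) {
    this.limit = limit;
    this.windowMs = windowMs;
    this.buckets = new Map(); // clientId -> { count, resetAt }
  }

  // Returns true if the client may send another message now.
  allow(clientId, now = Date.now()) {
    let bucket = this.buckets.get(clientId);
    if (!bucket || now > bucket.resetAt) {
      bucket = { count: 0, resetAt: now + this.windowMs };
      this.buckets.set(clientId, bucket);
    }
    if (bucket.count >= this.limit) return false;
    bucket.count++;
    return true;
  }

  // Call periodically (e.g. from setInterval) so the map cannot grow forever.
  evictExpired(now = Date.now()) {
    for (const [id, bucket] of this.buckets) {
      if (now > bucket.resetAt) this.buckets.delete(id);
    }
  }
}
```

Wire it into the message handler as `if (!limiter.allow(clientId)) { /* send rate-limit error */ return; }` and schedule `evictExpired()` alongside the heartbeat timer.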
Input Validation
// Server: Validate and sanitize input
function handleMessage(ws, data) {
// Validate message size before parsing
if (data.length > 10000) {
ws.send(JSON.stringify({
type: 'error',
message: 'Message too large'
}));
return;
}
let message;
try {
message = JSON.parse(data);
} catch (e) {
ws.send(JSON.stringify({
type: 'error',
message: 'Invalid JSON'
}));
return;
}
// Validate message structure
if (!message.type || typeof message.type !== 'string') {
ws.send(JSON.stringify({
type: 'error',
message: 'Invalid message format'
}));
return;
}
// Sanitize text content
if (message.content) {
message.content = sanitizeHtml(message.content);
}
// Process validated message
processMessage(ws, message);
}
Secure WebSocket (wss://)
// Server: Use HTTPS/WSS
const https = require('https');
const fs = require('fs');
const server = https.createServer({
cert: fs.readFileSync('cert.pem'),
key: fs.readFileSync('key.pem')
});
const wss = new WebSocket.Server({ server });
server.listen(443);
// Client: Connect with wss://
const socket = new WebSocket('wss://example.com/socket');
Best Practices
1. Heartbeat/Ping-Pong
// Server: Detect dead connections
const heartbeatInterval = setInterval(() => {
wss.clients.forEach((ws) => {
if (ws.isAlive === false) {
return ws.terminate();
}
ws.isAlive = false;
ws.ping();
});
}, 30000);
wss.on('connection', (ws) => {
ws.isAlive = true;
ws.on('pong', () => {
ws.isAlive = true;
});
});
// Client: Respond to pings (automatic in browsers)
// Or implement custom heartbeat:
setInterval(() => {
socket.send(JSON.stringify({ type: 'ping' }));
}, 30000);
2. Reconnection Strategy
// Client: Exponential backoff
class ReconnectingWebSocket {
constructor(url) {
this.url = url;
this.reconnectDelay = 1000;
this.maxReconnectDelay = 30000;
this.reconnectAttempts = 0;
this.connect();
}
connect() {
this.ws = new WebSocket(this.url);
this.ws.onopen = () => {
console.log('Connected');
this.reconnectDelay = 1000;
this.reconnectAttempts = 0;
};
this.ws.onclose = () => {
console.log('Disconnected');
this.scheduleReconnect();
};
}
scheduleReconnect() {
const delay = Math.min(
this.reconnectDelay * Math.pow(2, this.reconnectAttempts),
this.maxReconnectDelay
);
console.log(`Reconnecting in ${delay}ms`);
setTimeout(() => {
this.reconnectAttempts++;
this.connect();
}, delay);
}
}
3. Message Queuing
// Client: Queue messages when disconnected
class QueuedWebSocket {
constructor(url) {
this.url = url;
this.queue = [];
this.connect();
}
connect() {
this.ws = new WebSocket(this.url);
this.ws.onopen = () => {
// Send queued messages
while (this.queue.length > 0) {
this.ws.send(this.queue.shift());
}
};
}
send(data) {
if (this.ws.readyState === WebSocket.OPEN) {
this.ws.send(data);
} else {
this.queue.push(data);
}
}
}
4. Binary Data
// Send binary efficiently
const buffer = new ArrayBuffer(8);
const view = new DataView(buffer);
view.setUint32(0, 12345);
view.setFloat32(4, 3.14);
socket.send(buffer);
// Receive binary
socket.binaryType = 'arraybuffer';
socket.onmessage = (event) => {
if (event.data instanceof ArrayBuffer) {
const view = new DataView(event.data);
const num = view.getUint32(0);
const float = view.getFloat32(4);
}
};
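Building on the DataView example, a common pattern for binary messages is a tiny frame: a one-byte message type followed by a UTF-8 payload. A sketch (the frame layout is illustrative, not a standard):

```javascript
// Encode: [1-byte type][UTF-8 payload]; decode reverses it.
function encodeFrame(type, text) {
  const payload = new TextEncoder().encode(text);
  const bytes = new Uint8Array(1 + payload.length);
  bytes[0] = type;
  bytes.set(payload, 1);
  return bytes.buffer; // ArrayBuffer, ready for socket.send()
}

function decodeFrame(arrayBuffer) {
  const bytes = new Uint8Array(arrayBuffer);
  return {
    type: bytes[0],
    text: new TextDecoder().decode(bytes.subarray(1))
  };
}
```

On the receiving side, dispatch on `decodeFrame(event.data).type` instead of parsing JSON, which avoids the JSON overhead for high-frequency messages.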
5. Compression
// Server: Enable per-message deflate
const wss = new WebSocket.Server({
server,
perMessageDeflate: {
zlibDeflateOptions: {
chunkSize: 1024,
memLevel: 7,
level: 3
},
zlibInflateOptions: {
chunkSize: 10 * 1024
},
clientNoContextTakeover: true,
serverNoContextTakeover: true,
serverMaxWindowBits: 10,
concurrencyLimit: 10,
threshold: 1024 // Compress only messages > 1KB
}
});
Debugging
Browser DevTools
// Chrome/Firefox DevTools
// Network tab → WS/Messages
// View frames
// - Sent (green arrow)
// - Received (red arrow)
// - Click to view content
// Console logging
const socket = new WebSocket('ws://localhost:8080');
socket.addEventListener('message', (event) => {
console.log('%c⬇ Received', 'color: blue', event.data);
});
socket.send = new Proxy(socket.send, {
apply(target, thisArg, args) {
console.log('%c⬆ Sent', 'color: green', args[0]);
return target.apply(thisArg, args);
}
});
Command-Line Tools
# wscat - WebSocket client
npm install -g wscat
# Connect to server
wscat -c ws://localhost:8080
# Send message
> {"type": "chat", "message": "Hello"}
# Listen for messages
< {"type": "message", "content": "Hi there"}
# WebSocket with headers
wscat -c ws://localhost:8080 -H "Authorization: Bearer token"
# Test wss:// with self-signed cert
wscat -c wss://localhost:443 -n
# websocat - More features
cargo install websocat
# Connect
websocat ws://localhost:8080
# Binary mode
websocat --binary ws://localhost:8080
# tcpdump - Capture WebSocket traffic
sudo tcpdump -i any -A 'tcp port 8080'
# Wireshark
# Filter: websocket
# Analyze → Decode As → WebSocket
Common Issues
Issue: Connection fails immediately
Causes:
- Wrong URL (ws:// vs wss://)
- Server not running
- Firewall blocking port
- CORS/Origin mismatch
Solution:
- Check server logs
- Verify URL and port
- Check browser console for errors
- Validate origin on server
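When diagnosing failed connections, the close event's numeric code narrows down the cause. A small lookup helper for logging (code meanings per RFC 6455; the function name is illustrative):

```javascript
// Map common RFC 6455 close codes to human-readable causes for logging.
const CLOSE_CODES = {
  1000: 'Normal closure',
  1001: 'Going away (server shutdown or page navigation)',
  1002: 'Protocol error',
  1006: 'Abnormal closure (no close frame - often network/firewall/TLS issues)',
  1008: 'Policy violation (e.g. bad origin or failed auth)',
  1009: 'Message too big',
  1011: 'Internal server error',
  1015: 'TLS handshake failure'
};

function describeClose(event) {
  const reason = CLOSE_CODES[event.code] || 'Unknown code';
  return `WebSocket closed: ${event.code} (${reason})` +
    (event.reason ? ` - ${event.reason}` : '');
}
```

Hook it up with `socket.onclose = (e) => console.log(describeClose(e));` — a 1006 with no server-side log entry usually points at the network path rather than the application.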
Issue: Connection drops frequently
Causes:
- No heartbeat/ping
- Idle timeout
- Network issues
- Proxy timeout
Solution:
- Implement ping/pong
- Send periodic messages
- Reduce ping interval
- Use wss:// for better stability
Issue: Messages not received
Causes:
- Wrong readyState
- Connection closed
- Message too large
- Server not broadcasting
Solution:
- Check socket.readyState === OPEN
- Add message queuing
- Split large messages
- Verify server broadcast logic
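The "split large messages" suggestion above can be sketched as a chunking helper: each chunk carries an ID, sequence number, and total so the receiver can reassemble (the envelope field names are illustrative):

```javascript
// Split a large payload into fixed-size chunks with a small JSON envelope.
function chunkMessage(id, text, chunkSize = 16 * 1024) {
  const total = Math.ceil(text.length / chunkSize) || 1;
  const chunks = [];
  for (let i = 0; i < total; i++) {
    chunks.push(JSON.stringify({
      type: 'chunk',
      id,      // message ID shared by all chunks
      seq: i,  // chunk index
      total,   // total number of chunks
      data: text.slice(i * chunkSize, (i + 1) * chunkSize)
    }));
  }
  return chunks;
}
```

The receiver buffers chunks by `id` until `total` of them have arrived, then joins `data` in `seq` order.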
Issue: High memory usage
Causes:
- Not closing connections
- Large message buffers
- Too many connections
- Memory leaks
Solution:
- Close unused connections
- Limit message size
- Set max connections
- Use heartbeat to detect dead connections
ELI10: WebSocket Explained Simply
WebSocket is like having a phone call instead of sending letters:
Traditional HTTP (Letters)
You: "Any new messages?" [wait for response]
Server: "No"
[1 second later]
You: "Any new messages?" [wait for response]
Server: "No"
[1 second later]
You: "Any new messages?" [wait for response]
Server: "Yes! Here's one"
Problem: Lots of wasted "letters" (requests)
WebSocket (Phone Call)
You: "Hello!" [open connection]
Server: "Hi!" [connection open]
[Connection stays open]
Server: "New message for you!"
You: "Thanks! Here's my reply"
Server: "Got it!"
You: "Question?"
Server: "Answer!"
Connection stays open until you hang up
Key Differences
HTTP:
- Ask → Wait → Answer → Close
- Repeat every time
- Like knocking on door for each question
WebSocket:
- Open door once
- Walk in and stay
- Talk back and forth
- Like having a conversation
Real Examples
HTTP: Checking email every minute
WebSocket: Email app shows new mail instantly
HTTP: Refreshing page to see chat messages
WebSocket: Messages appear as sent
HTTP: Reloading dashboard for new data
WebSocket: Dashboard updates in real-time
Further Resources
Specifications
- RFC 6455 - The WebSocket Protocol
- WHATWG HTML Living Standard - WebSocket API
Libraries
JavaScript (Client)
- Native WebSocket API (built-in)
- Socket.IO - High-level library with fallbacks
- SockJS - WebSocket emulation
Node.js (Server)
- ws - Fast, standards-compliant
- Socket.IO - Client + server library
- uWebSockets.js - Ultra fast
Python
- websockets - asyncio library
- aiohttp - WebSocket support
- Flask-SocketIO - Flask integration
Go
- gorilla/websocket - Widely used, battle-tested
- coder/websocket - Minimal, idiomatic API
Rust
- tokio-tungstenite
- actix-web - WebSocket support
Tools
Testing
- WebSocket King - Online tester
- PieSocket - Testing tool
Books & Tutorials
- High Performance Browser Networking (Ilya Grigorik) - WebSocket chapter
- MDN WebSockets API guide
WebRTC (Web Real-Time Communication)
Overview
WebRTC (Web Real-Time Communication) is an open-source framework that enables real-time peer-to-peer communication directly between web browsers and mobile applications. It supports video, audio, and arbitrary data transfer without requiring plugins or third-party software.
Key Features
1. Peer-to-Peer Communication
- Direct browser-to-browser connections
- Low latency (no server relay required*)
- Reduced bandwidth costs
2. Media Support
- Audio streaming
- Video streaming
- Screen sharing
- Data channels for arbitrary data
3. Built-in Security
- Mandatory encryption (DTLS, SRTP)
- No unencrypted media transmission
- Secure signaling required
4. NAT/Firewall Traversal
- ICE protocol for connectivity
- STUN for public address discovery
- TURN as relay fallback
5. Adaptive Quality
- Bandwidth estimation
- Codec negotiation
- Quality adjusts to network conditions
* Direct P2P when possible; TURN relay as fallback
WebRTC Architecture
┌─────────────────────────────────────────────────────────────┐
│ WebRTC Application │
│ (JavaScript API in browser or native mobile app) │
└────────────────────┬────────────────────────────────────────┘
│
┌───────────────┼───────────────┐
│ │ │
▼ ▼ ▼
┌─────────┐ ┌──────────┐ ┌──────────┐
│ Media │ │ Data │ │ Signaling│
│ Streams │ │ Channels │ │ (Custom) │
└─────────┘ └──────────┘ └──────────┘
│ │ │
▼ ▼ │
┌─────────────────────────┐ │
│ WebRTC Core APIs │ │
│ │ │
│ - getUserMedia() │ │
│ - RTCPeerConnection │ │
│ - RTCDataChannel │ │
└─────────────────────────┘ │
│ │
▼ │
┌─────────────────────────┐ │
│ ICE/STUN/TURN │ │
│ (NAT Traversal) │ │
└─────────────────────────┘ │
│ │
└───────────────┬───────────────┘
│
▼
┌──────────────────┐
│ Network Layer │
│ (UDP/TCP/TLS) │
└──────────────────┘
Core Components
1. getUserMedia API
Access local camera and microphone:
// Basic usage
async function getLocalMedia() {
try {
const stream = await navigator.mediaDevices.getUserMedia({
video: true,
audio: true
});
// Display local video
document.getElementById('localVideo').srcObject = stream;
return stream;
} catch (error) {
console.error('Error accessing media devices:', error);
}
}
// Advanced constraints
const constraints = {
video: {
width: { min: 640, ideal: 1280, max: 1920 },
height: { min: 480, ideal: 720, max: 1080 },
frameRate: { ideal: 30, max: 60 },
facingMode: 'user' // or 'environment' for rear camera
},
audio: {
echoCancellation: true,
noiseSuppression: true,
autoGainControl: true
}
};
const stream = await navigator.mediaDevices.getUserMedia(constraints);
// List available devices
const devices = await navigator.mediaDevices.enumerateDevices();
devices.forEach(device => {
console.log(`${device.kind}: ${device.label} (${device.deviceId})`);
});
// Screen sharing
const screenStream = await navigator.mediaDevices.getDisplayMedia({
video: {
cursor: 'always',
displaySurface: 'monitor' // 'window', 'application', 'browser'
},
audio: false
});
2. RTCPeerConnection
Core API for peer-to-peer connection:
// Create peer connection
const configuration = {
iceServers: [
{ urls: 'stun:stun.l.google.com:19302' },
{ urls: 'stun:stun1.l.google.com:19302' },
{
urls: 'turn:turn.example.com:3478',
username: 'user',
credential: 'pass'
}
],
iceCandidatePoolSize: 10
};
const peerConnection = new RTCPeerConnection(configuration);
// Add local stream to connection
localStream.getTracks().forEach(track => {
peerConnection.addTrack(track, localStream);
});
// Listen for remote stream
peerConnection.ontrack = (event) => {
const remoteVideo = document.getElementById('remoteVideo');
if (remoteVideo.srcObject !== event.streams[0]) {
remoteVideo.srcObject = event.streams[0];
console.log('Received remote stream');
}
};
// Handle ICE candidates
peerConnection.onicecandidate = (event) => {
if (event.candidate) {
// Send candidate to remote peer via signaling
sendToSignalingServer({
type: 'ice-candidate',
candidate: event.candidate
});
}
};
// Monitor connection state
peerConnection.onconnectionstatechange = () => {
console.log('Connection state:', peerConnection.connectionState);
// States: new, connecting, connected, disconnected, failed, closed
};
peerConnection.oniceconnectionstatechange = () => {
console.log('ICE state:', peerConnection.iceConnectionState);
// States: new, checking, connected, completed, failed, disconnected, closed
};
3. RTCDataChannel
Bi-directional data transfer:
// Sender creates data channel
const dataChannel = peerConnection.createDataChannel('chat', {
ordered: true, // Guarantee order
maxRetransmits: 3 // Retry failed messages 3 times
// OR: maxPacketLifeTime: 3000 // Drop after 3 seconds
});
dataChannel.onopen = () => {
console.log('Data channel opened');
dataChannel.send('Hello!');
};
dataChannel.onmessage = (event) => {
console.log('Received:', event.data);
};
dataChannel.onerror = (error) => {
console.error('Data channel error:', error);
};
dataChannel.onclose = () => {
console.log('Data channel closed');
};
// Receiver listens for data channel
peerConnection.ondatachannel = (event) => {
const receiveChannel = event.channel;
receiveChannel.onmessage = (event) => {
console.log('Received:', event.data);
};
receiveChannel.onopen = () => {
console.log('Receive channel opened');
};
};
// Send different data types
dataChannel.send('Text message');
dataChannel.send(JSON.stringify({ type: 'chat', message: 'Hi' }));
dataChannel.send(new Uint8Array([1, 2, 3, 4])); // Binary
dataChannel.send(new Blob(['file content'])); // Blob
// Check buffered amount before sending large data
// (for backpressure, also see bufferedAmountLowThreshold / onbufferedamountlow)
if (dataChannel.bufferedAmount === 0) {
dataChannel.send(largeData);
}
Connection Establishment (Signaling)
WebRTC doesn't define signaling - you implement it yourself:
Offer/Answer Exchange (SDP)
// ============================================
// Caller (Initiator)
// ============================================
// 1. Create offer
const offer = await peerConnection.createOffer({
offerToReceiveAudio: true,
offerToReceiveVideo: true
});
// 2. Set local description
await peerConnection.setLocalDescription(offer);
// 3. Send offer to remote peer via signaling
sendToSignalingServer({
type: 'offer',
sdp: peerConnection.localDescription
});
// 4. Receive answer from signaling server
signalingSocket.on('answer', async (answer) => {
await peerConnection.setRemoteDescription(
new RTCSessionDescription(answer)
);
});
// ============================================
// Callee (Responder)
// ============================================
// 1. Receive offer from signaling server
signalingSocket.on('offer', async (offer) => {
// 2. Set remote description
await peerConnection.setRemoteDescription(
new RTCSessionDescription(offer)
);
// 3. Create answer
const answer = await peerConnection.createAnswer();
// 4. Set local description
await peerConnection.setLocalDescription(answer);
// 5. Send answer back via signaling
sendToSignalingServer({
type: 'answer',
sdp: peerConnection.localDescription
});
});
// ============================================
// Both Peers
// ============================================
// Handle ICE candidates
peerConnection.onicecandidate = (event) => {
if (event.candidate) {
sendToSignalingServer({
type: 'ice-candidate',
candidate: event.candidate
});
}
};
// Receive ICE candidates from signaling
signalingSocket.on('ice-candidate', async (candidate) => {
try {
await peerConnection.addIceCandidate(
new RTCIceCandidate(candidate)
);
} catch (error) {
console.error('Error adding ICE candidate:', error);
}
});
SDP (Session Description Protocol)
SDP describes the media session:
Example SDP Offer:
v=0
o=- 123456789 2 IN IP4 127.0.0.1
s=-
t=0 0
a=group:BUNDLE 0 1
a=msid-semantic: WMS stream1
m=audio 9 UDP/TLS/RTP/SAVPF 111 103 104
c=IN IP4 0.0.0.0
a=rtcp:9 IN IP4 0.0.0.0
a=ice-ufrag:F7gI
a=ice-pwd:x9cml6RvRClHPcAy
a=ice-options:trickle
a=fingerprint:sha-256 8B:87:09:8A:5D:C2:...
a=setup:actpass
a=mid:0
a=sendrecv
a=rtcp-mux
a=rtpmap:111 opus/48000/2
a=rtpmap:103 ISAC/16000
a=rtpmap:104 ISAC/32000
m=video 9 UDP/TLS/RTP/SAVPF 96 97 98
c=IN IP4 0.0.0.0
a=rtcp:9 IN IP4 0.0.0.0
a=ice-ufrag:F7gI
a=ice-pwd:x9cml6RvRClHPcAy
a=ice-options:trickle
a=fingerprint:sha-256 8B:87:09:8A:5D:C2:...
a=setup:actpass
a=mid:1
a=sendrecv
a=rtcp-mux
a=rtpmap:96 VP8/90000
a=rtpmap:97 VP9/90000
a=rtpmap:98 H264/90000
Key Fields:
- v=0: SDP version
- m=: Media description (audio/video)
- c=: Connection information
- a=: Attributes (ICE, codecs, etc.)
- rtpmap: RTP payload mapping
- ice-ufrag/ice-pwd: ICE credentials
- fingerprint: DTLS certificate fingerprint
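To see which codecs an offer advertises, the `a=rtpmap:` attributes can be pulled out with a few lines of string processing (a minimal sketch; real SDP parsing has many more cases in the grammar):

```javascript
// Extract payload-type -> codec mappings from the a=rtpmap lines of an SDP.
function parseRtpMap(sdp) {
  const codecs = {};
  for (const line of sdp.split(/\r?\n/)) {
    const match = line.match(/^a=rtpmap:(\d+) ([^/]+)\/(\d+)(?:\/(\d+))?/);
    if (match) {
      codecs[match[1]] = {
        name: match[2],
        clockRate: Number(match[3]),
        channels: match[4] ? Number(match[4]) : 1
      };
    }
  }
  return codecs;
}
```

Running it on `peerConnection.localDescription.sdp` after `setLocalDescription()` shows the negotiable codec list for debugging.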
Signaling Implementation Examples
WebSocket Signaling Server (Node.js)
// Server
const WebSocket = require('ws');
const wss = new WebSocket.Server({ port: 8080 });
const rooms = new Map(); // roomId -> Set of clients
wss.on('connection', (ws) => {
console.log('Client connected');
ws.on('message', (data) => {
const message = JSON.parse(data);
switch (message.type) {
case 'join':
// Join room
if (!rooms.has(message.room)) {
rooms.set(message.room, new Set());
}
rooms.get(message.room).add(ws);
ws.room = message.room;
// Notify others in room
broadcast(message.room, ws, {
type: 'user-joined',
userId: message.userId
});
break;
case 'offer':
case 'answer':
case 'ice-candidate':
// Forward to specific peer or broadcast
if (message.target) {
sendToUser(message.target, message);
} else {
broadcast(ws.room, ws, message);
}
break;
case 'leave':
leaveRoom(ws);
break;
}
});
ws.on('close', () => {
console.log('Client disconnected');
leaveRoom(ws);
});
});
function broadcast(room, sender, message) {
if (!rooms.has(room)) return;
rooms.get(room).forEach(client => {
if (client !== sender && client.readyState === WebSocket.OPEN) {
client.send(JSON.stringify(message));
}
});
}
function leaveRoom(ws) {
if (ws.room && rooms.has(ws.room)) {
rooms.get(ws.room).delete(ws);
broadcast(ws.room, ws, {
type: 'user-left',
userId: ws.userId
});
}
}
console.log('Signaling server running on ws://localhost:8080');
Client-Side Signaling
// Client
class SignalingClient {
constructor(url) {
this.socket = new WebSocket(url);
this.handlers = new Map();
this.socket.onmessage = (event) => {
const message = JSON.parse(event.data);
const handler = this.handlers.get(message.type);
if (handler) {
handler(message);
}
};
this.socket.onopen = () => {
console.log('Signaling connected');
};
this.socket.onerror = (error) => {
console.error('Signaling error:', error);
};
this.socket.onclose = () => {
console.log('Signaling disconnected');
};
}
on(type, handler) {
this.handlers.set(type, handler);
}
send(message) {
this.socket.send(JSON.stringify(message));
}
join(room, userId) {
this.send({ type: 'join', room, userId });
}
sendOffer(offer, target) {
this.send({ type: 'offer', sdp: offer, target });
}
sendAnswer(answer, target) {
this.send({ type: 'answer', sdp: answer, target });
}
sendIceCandidate(candidate, target) {
this.send({ type: 'ice-candidate', candidate, target });
}
}
// Usage
const signaling = new SignalingClient('ws://localhost:8080');
signaling.on('offer', handleOffer);
signaling.on('answer', handleAnswer);
signaling.on('ice-candidate', handleIceCandidate);
signaling.join('room123', 'user1');
Complete WebRTC Example
Simple Video Chat Application
class WebRTCVideoChat {
constructor(signalingUrl) {
this.signaling = new SignalingClient(signalingUrl);
this.peerConnection = null;
this.localStream = null;
this.setupSignaling();
}
setupSignaling() {
this.signaling.on('offer', async (message) => {
await this.handleOffer(message.sdp, message.sender);
});
this.signaling.on('answer', async (message) => {
await this.handleAnswer(message.sdp);
});
this.signaling.on('ice-candidate', async (message) => {
await this.handleIceCandidate(message.candidate);
});
this.signaling.on('user-joined', (message) => {
console.log('User joined:', message.userId);
// Initiate call if you're the caller
});
}
async start(localVideoElement, remoteVideoElement) {
// Get local media
this.localStream = await navigator.mediaDevices.getUserMedia({
video: { width: 1280, height: 720 },
audio: true
});
localVideoElement.srcObject = this.localStream;
// Create peer connection
this.peerConnection = new RTCPeerConnection({
iceServers: [
{ urls: 'stun:stun.l.google.com:19302' }
]
});
// Add local stream
this.localStream.getTracks().forEach(track => {
this.peerConnection.addTrack(track, this.localStream);
});
// Handle remote stream
this.peerConnection.ontrack = (event) => {
remoteVideoElement.srcObject = event.streams[0];
};
// Handle ICE candidates
this.peerConnection.onicecandidate = (event) => {
if (event.candidate) {
this.signaling.sendIceCandidate(event.candidate);
}
};
// Monitor connection
this.peerConnection.onconnectionstatechange = () => {
console.log('Connection state:',
this.peerConnection.connectionState);
};
}
async call() {
// Create and send offer
const offer = await this.peerConnection.createOffer();
await this.peerConnection.setLocalDescription(offer);
this.signaling.sendOffer(offer);
}
async handleOffer(offer, sender) {
await this.peerConnection.setRemoteDescription(
new RTCSessionDescription(offer)
);
const answer = await this.peerConnection.createAnswer();
await this.peerConnection.setLocalDescription(answer);
this.signaling.sendAnswer(answer, sender);
}
async handleAnswer(answer) {
await this.peerConnection.setRemoteDescription(
new RTCSessionDescription(answer)
);
}
async handleIceCandidate(candidate) {
await this.peerConnection.addIceCandidate(
new RTCIceCandidate(candidate)
);
}
hangup() {
if (this.peerConnection) {
this.peerConnection.close();
this.peerConnection = null;
}
if (this.localStream) {
this.localStream.getTracks().forEach(track => track.stop());
this.localStream = null;
}
}
toggleAudio() {
const audioTrack = this.localStream.getAudioTracks()[0];
audioTrack.enabled = !audioTrack.enabled;
return audioTrack.enabled;
}
toggleVideo() {
const videoTrack = this.localStream.getVideoTracks()[0];
videoTrack.enabled = !videoTrack.enabled;
return videoTrack.enabled;
}
}
// Usage
const chat = new WebRTCVideoChat('ws://localhost:8080');
const localVideo = document.getElementById('localVideo');
const remoteVideo = document.getElementById('remoteVideo');
await chat.start(localVideo, remoteVideo);
chat.signaling.join('room123', 'user1');
// When ready to call
document.getElementById('callButton').onclick = () => chat.call();
document.getElementById('hangupButton').onclick = () => chat.hangup();
document.getElementById('muteButton').onclick = () => chat.toggleAudio();
document.getElementById('videoButton').onclick = () => chat.toggleVideo();
Media Codecs
Audio Codecs
Opus (Preferred)
- Bitrate: 6-510 kbps
- Latency: 5-66.5 ms
- Best quality and efficiency
- Supports stereo and mono
- Adaptive bitrate
G.711 (PCMU/PCMA)
- Bitrate: 64 kbps
- Latency: Low
- Widely supported
- Lower quality than Opus
iSAC
- Bitrate: 10-32 kbps
- Adaptive bitrate
- Good for low bandwidth
iLBC
- Bitrate: 13.33 or 15.2 kbps
- Packet loss resilience
- Voice only
Video Codecs
VP8 (Mandatory in WebRTC)
- Open source
- Good quality
- Hardware acceleration common
- Bitrate: 100-2000 kbps typically
VP9 (Better than VP8)
- 50% better compression than VP8
- Supports 4K
- Lower bandwidth usage
- Newer, less hardware support
H.264 (Most compatible)
- Patent-encumbered
- Excellent hardware support
- Multiple profiles (Baseline, Main, High)
- Most widely supported
AV1 (Future)
- Best compression
- Open source
- Still emerging
- Limited hardware support
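A less fragile alternative to SDP munging is `RTCRtpTransceiver.setCodecPreferences()`, which takes a reordered copy of the browser's capability list. The reordering itself is plain array work (mimeType values follow the `'video/VP9'` convention; the helper name is illustrative):

```javascript
// Reorder a codec capability list so the preferred mimeType comes first,
// suitable for passing to RTCRtpTransceiver.setCodecPreferences().
function preferMimeType(codecs, mimeType) {
  const wanted = mimeType.toLowerCase();
  const preferred = codecs.filter(c => c.mimeType.toLowerCase() === wanted);
  const rest = codecs.filter(c => c.mimeType.toLowerCase() !== wanted);
  return [...preferred, ...rest];
}

// In a browser:
// const { codecs } = RTCRtpSender.getCapabilities('video');
// transceiver.setCodecPreferences(preferMimeType(codecs, 'video/VP9'));
```

Unlike SDP string edits, this must happen before `createOffer()`, and the array must only contain entries from `getCapabilities()`.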
Codec Selection
// Prefer specific codec
function preferCodec(sdp, codecName) {
const lines = sdp.split('\n');
const mLineIndex = lines.findIndex(line => line.startsWith('m=video'));
if (mLineIndex === -1) return sdp;
const codecRegex = new RegExp(`rtpmap:(\\d+) ${codecName}`, 'i');
const codecPayload = lines
.find(line => codecRegex.test(line))
?.match(codecRegex)?.[1];
if (!codecPayload) return sdp;
const mLine = lines[mLineIndex].split(' ');
const codecs = mLine.slice(3);
// Move preferred codec to front
const newCodecs = [
codecPayload,
...codecs.filter(c => c !== codecPayload)
];
mLine.splice(3, codecs.length, ...newCodecs);
lines[mLineIndex] = mLine.join(' ');
return lines.join('\n');
}
// Usage
const offer = await peerConnection.createOffer();
offer.sdp = preferCodec(offer.sdp, 'VP9');
await peerConnection.setLocalDescription(offer);
Quality Adaptation
Bandwidth Estimation
// Monitor bandwidth
peerConnection.getStats().then(stats => {
stats.forEach(report => {
if (report.type === 'candidate-pair' && report.state === 'succeeded') {
console.log('Available bandwidth:',
report.availableOutgoingBitrate);
console.log('Current RTT:',
report.currentRoundTripTime);
}
if (report.type === 'inbound-rtp' && report.mediaType === 'video') {
console.log('Bytes received:', report.bytesReceived);
console.log('Packets lost:', report.packetsLost);
console.log('Jitter:', report.jitter);
}
});
});
// Periodic monitoring
setInterval(async () => {
const stats = await peerConnection.getStats();
analyzeStats(stats);
}, 1000);
Simulcast (Multiple Qualities)
// Sender: Send multiple resolutions
const sender = peerConnection
.getSenders()
.find(s => s.track.kind === 'video');
const parameters = sender.getParameters();
if (!parameters.encodings) {
parameters.encodings = [
{ rid: 'h', maxBitrate: 1500000 }, // High quality
{ rid: 'm', maxBitrate: 600000, scaleResolutionDownBy: 2 }, // Medium
{ rid: 'l', maxBitrate: 200000, scaleResolutionDownBy: 4 } // Low
];
}
await sender.setParameters(parameters);
// Receiver: layer selection
const receiver = peerConnection
.getReceivers()
.find(r => r.track.kind === 'video');
// Note: a receiver cannot request a simulcast layer through
// RTCRtpReceiver; with simulcast, an SFU (media server) chooses
// which encoding (rid 'h'/'m'/'l') to forward, typically via the
// server's own API.
Manual Bitrate Control
async function setMaxBitrate(peerConnection, maxBitrate) {
const sender = peerConnection
.getSenders()
.find(s => s.track.kind === 'video');
const parameters = sender.getParameters();
if (!parameters.encodings) {
parameters.encodings = [{}];
}
parameters.encodings[0].maxBitrate = maxBitrate;
await sender.setParameters(parameters);
console.log(`Set max bitrate to ${maxBitrate} bps`);
}
// Usage
setMaxBitrate(peerConnection, 500000); // 500 kbps
Data Channel Use Cases
File Transfer
class FileTransfer {
constructor(dataChannel) {
this.channel = dataChannel;
this.chunkSize = 16384; // 16 KB chunks
}
async sendFile(file) {
const arrayBuffer = await file.arrayBuffer();
const totalChunks = Math.ceil(arrayBuffer.byteLength / this.chunkSize);
// Send metadata
this.channel.send(JSON.stringify({
type: 'file-start',
name: file.name,
size: file.size,
totalChunks: totalChunks
}));
// Send chunks
for (let i = 0; i < totalChunks; i++) {
const start = i * this.chunkSize;
const end = Math.min(start + this.chunkSize, arrayBuffer.byteLength);
const chunk = arrayBuffer.slice(start, end);
// Wait if buffer is filling up
while (this.channel.bufferedAmount > this.chunkSize * 10) {
await new Promise(resolve => setTimeout(resolve, 10));
}
this.channel.send(chunk);
// Progress update
const progress = ((i + 1) / totalChunks * 100).toFixed(1);
console.log(`Sending: ${progress}%`);
}
// Send completion
this.channel.send(JSON.stringify({ type: 'file-end' }));
}
receiveFile(onProgress, onComplete) {
const chunks = [];
let metadata = null;
this.channel.onmessage = (event) => {
if (typeof event.data === 'string') {
const message = JSON.parse(event.data);
if (message.type === 'file-start') {
metadata = message;
chunks.length = 0;
} else if (message.type === 'file-end') {
const blob = new Blob(chunks);
onComplete(blob, metadata);
}
} else {
// Binary chunk
chunks.push(event.data);
if (metadata) {
const progress = (chunks.length / metadata.totalChunks * 100)
.toFixed(1);
onProgress(progress);
}
}
};
}
}
// Usage
const fileTransfer = new FileTransfer(dataChannel);
// Sender
document.getElementById('fileInput').onchange = async (e) => {
const file = e.target.files[0];
await fileTransfer.sendFile(file);
};
// Receiver
fileTransfer.receiveFile(
(progress) => console.log(`Receiving: ${progress}%`),
(blob, metadata) => {
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = metadata.name;
a.click();
}
);
Gaming/Real-time Data
class GameDataChannel {
constructor(dataChannel) {
this.channel = dataChannel;
this.channel.binaryType = 'arraybuffer';
// For lowest latency, the channel passed in should be created as
// unreliable and unordered:
// peerConnection.createDataChannel('game', { ordered: false, maxRetransmits: 0 })
}
sendPlayerPosition(x, y, angle) {
const buffer = new ArrayBuffer(12);
const view = new DataView(buffer);
view.setFloat32(0, x, true);
view.setFloat32(4, y, true);
view.setFloat32(8, angle, true);
this.channel.send(buffer);
}
onPlayerPosition(callback) {
this.channel.onmessage = (event) => {
const view = new DataView(event.data);
const x = view.getFloat32(0, true);
const y = view.getFloat32(4, true);
const angle = view.getFloat32(8, true);
callback(x, y, angle);
};
}
}
// Usage
const gameChannel = new GameDataChannel(dataChannel);
// Send position 60 times per second
setInterval(() => {
gameChannel.sendPlayerPosition(
player.x,
player.y,
player.angle
);
}, 1000 / 60);
gameChannel.onPlayerPosition((x, y, angle) => {
updateRemotePlayer(x, y, angle);
});
Security Considerations
Encryption
WebRTC Security Stack:
Application Data
↓
SRTP (Secure RTP)
- Encrypts media (audio/video)
- AES encryption
- HMAC authentication
↓
DTLS (Datagram TLS)
- Encrypts data channels
- Key exchange for SRTP
- Certificate verification
↓
UDP/TCP Transport
All WebRTC traffic is encrypted!
No option for unencrypted communication.
Certificate Verification
// Verify peer certificate fingerprint
peerConnection.onicecandidate = (event) => {
if (event.candidate === null) {
// Get local certificate
peerConnection.getConfiguration().certificates.forEach(cert => {
cert.getFingerprints().forEach(fingerprint => {
console.log('Local fingerprint:', fingerprint);
// Send to peer via secure signaling
// Peer should verify this matches SDP
});
});
}
};
// Check SDP fingerprint matches expected
function verifySdpFingerprint(sdp, expectedFingerprint) {
const fingerprintMatch = sdp.match(/a=fingerprint:(\S+) (\S+)/);
if (!fingerprintMatch) {
throw new Error('No fingerprint in SDP');
}
const [, algorithm, fingerprint] = fingerprintMatch;
if (fingerprint !== expectedFingerprint) {
throw new Error('Fingerprint mismatch! Possible MITM attack.');
}
return true;
}
Best Practices
1. Secure Signaling
- Use TLS/WSS for signaling
- Authenticate users
- Verify peer identity
2. Certificate Pinning
- Verify SDP fingerprints
- Out-of-band verification if possible
3. Access Control
- Verify room/session authorization
- Implement user authentication
- Rate limiting
4. Media Permissions
- Request minimal permissions
- Explain why access is needed
- Allow users to deny
5. Privacy
- Minimize data collection
- No recording without consent
- Clear privacy policy
6. Network Security
- Use TURN with authentication
- Restrict TURN access
- Monitor for abuse
Debugging and Troubleshooting
Enable Debug Logs
// Chrome: Enable WebRTC internals
// Navigate to: chrome://webrtc-internals
// Firefox: Enable logging
// Navigate to: about:webrtc
// Console logging
peerConnection.addEventListener('track', e => {
console.log('Track event:', e);
});
peerConnection.addEventListener('icecandidate', e => {
console.log('ICE candidate:', e.candidate);
});
peerConnection.addEventListener('icecandidateerror', e => {
console.error('ICE candidate error:', e);
});
peerConnection.addEventListener('connectionstatechange', e => {
console.log('Connection state:', peerConnection.connectionState);
});
peerConnection.addEventListener('iceconnectionstatechange', e => {
console.log('ICE connection state:',
peerConnection.iceConnectionState);
});
Get Detailed Statistics
async function getDetailedStats(peerConnection) {
const stats = await peerConnection.getStats();
const report = {};
stats.forEach(stat => {
if (stat.type === 'inbound-rtp' && stat.kind === 'video') {
report.video = {
bytesReceived: stat.bytesReceived,
packetsReceived: stat.packetsReceived,
packetsLost: stat.packetsLost,
jitter: stat.jitter,
frameWidth: stat.frameWidth,
frameHeight: stat.frameHeight,
framesPerSecond: stat.framesPerSecond,
framesDecoded: stat.framesDecoded,
framesDropped: stat.framesDropped
};
}
if (stat.type === 'inbound-rtp' && stat.kind === 'audio') {
report.audio = {
bytesReceived: stat.bytesReceived,
packetsReceived: stat.packetsReceived,
packetsLost: stat.packetsLost,
jitter: stat.jitter,
audioLevel: stat.audioLevel
};
}
if (stat.type === 'candidate-pair' && stat.state === 'succeeded') {
report.connection = {
localCandidateType: stat.localCandidateType,
remoteCandidateType: stat.remoteCandidateType,
currentRoundTripTime: stat.currentRoundTripTime,
availableOutgoingBitrate: stat.availableOutgoingBitrate,
bytesReceived: stat.bytesReceived,
bytesSent: stat.bytesSent
};
}
});
return report;
}
// Monitor every second
setInterval(async () => {
const stats = await getDetailedStats(peerConnection);
console.table(stats);
}, 1000);
Common Issues and Solutions
Issue: ICE connection fails
Solutions:
- Check STUN/TURN server configuration
- Verify firewall allows UDP traffic
- Add TURN server as fallback
- Check ICE candidate gathering
Issue: No video/audio
Solutions:
- Verify getUserMedia constraints
- Check browser permissions
- Verify tracks added to peer connection
- Check ontrack event handler
Issue: One-way audio/video
Solutions:
- Verify both peers add tracks
- Check SDP offer/answer exchange
- Verify both peers handle ontrack
- Check NAT/firewall rules
Issue: Poor quality
Solutions:
- Reduce resolution/bitrate
- Enable simulcast
- Check network bandwidth
- Monitor packet loss
- Verify codec support
Issue: High latency
Solutions:
- Use TURN server closer to users
- Enable unreliable data channels for gaming
- Reduce buffering
- Optimize codec settings
Browser Support
Desktop Browsers:
✓ Chrome 23+
✓ Firefox 22+
✓ Safari 11+
✓ Edge 79+ (Chromium-based)
✓ Opera 18+
Mobile Browsers:
✓ Chrome Android 28+
✓ Firefox Android 24+
✓ Safari iOS 11+
✓ Samsung Internet 4+
Feature Support:
- getUserMedia: All modern browsers
- RTCPeerConnection: All modern browsers
- RTCDataChannel: All modern browsers
- Screen sharing: Desktop only (most browsers)
- VP9 codec: Chrome, Firefox, Edge
- H.264 codec: All browsers (licensing)
Check: https://caniuse.com/rtcpeerconnection
Performance Optimization
Tips for Better Performance
// 1. Reuse peer connections
const peerConnections = new Map();
function getOrCreatePeerConnection(peerId) {
if (!peerConnections.has(peerId)) {
peerConnections.set(peerId, createPeerConnection());
}
return peerConnections.get(peerId);
}
// 2. Batch trickled ICE candidates to reduce signaling messages
const pendingCandidates = [];
peerConnection.onicecandidate = (event) => {
if (event.candidate) {
pendingCandidates.push(event.candidate);
// Send in batches
if (pendingCandidates.length >= 5) {
signaling.send({
type: 'ice-candidates',
candidates: pendingCandidates.splice(0)
});
}
}
};
// 3. Use efficient codecs
// VP9 or H.264 for video, Opus for audio
// 4. Enable hardware acceleration
// Automatic in most browsers
// 5. Limit resolution based on network
async function adaptToNetwork(peerConnection) {
const stats = await peerConnection.getStats();
// Analyze and adjust bitrate/resolution
}
// 6. Use object fit for video elements
// <video style="object-fit: cover;"></video>
// 7. Clean up resources
function cleanup() {
localStream?.getTracks().forEach(track => track.stop());
peerConnection?.close();
dataChannel?.close();
}
ELI10: WebRTC Explained Simply
WebRTC lets browsers talk directly to each other without a server in the middle:
Traditional Communication
Your Browser → Server → Friend's Browser
- Everything goes through server
- Server sees all your data
- Costs more (server bandwidth)
- Higher latency
WebRTC Communication
Your Browser ←→ Friend's Browser
- Direct connection (peer-to-peer)
- Server only introduces you
- Private (server can't see)
- Faster (no middleman)
The Process
1. Get Permission
"Can I use your camera and microphone?"
2. Signaling (Meeting)
Server: "Hey Browser A, meet Browser B"
Exchange: "Here's how to reach me"
3. ICE/STUN (Finding the Path)
"What's my public address?"
"Can we connect directly?"
4. Connection!
Direct video/audio/data
Encrypted automatically
5. If Direct Fails
TURN server relays traffic
Still encrypted
Real-World Analogy
Traditional: Passing notes through teacher
WebRTC: Sitting next to friend and talking
Signaling: Teacher introduces you
STUN: Finding where each person sits
TURN: Using walkie-talkies if too far
Further Resources
Tools
- chrome://webrtc-internals - Chrome debugging
- about:webrtc - Firefox debugging
- WebRTC Troubleshooter
Libraries
- SimpleWebRTC - Simplified WebRTC
- PeerJS - Easy peer-to-peer
- Janus Gateway - WebRTC server
- Kurento - Media server
Books
- Real-Time Communication with WebRTC by Salvatore Loreto
- WebRTC Cookbook by Andrii Sergiienko
- High Performance Browser Networking by Ilya Grigorik
Finance
Overview
Finance is the management of money, investments, and other financial instruments. This guide covers various aspects of financial markets, investment strategies, and trading concepts essential for understanding modern finance and making informed investment decisions.
What is Finance?
Finance encompasses the creation, management, and study of money, banking, credit, investments, assets, and liabilities. It involves:
- Personal Finance: Managing individual/household money
- Corporate Finance: Managing business finances
- Public Finance: Government revenue and expenditure
- Investment Finance: Growing wealth through financial instruments
Financial Markets
Market Types
- Stock Market: Equity securities (shares of companies)
- Bond Market: Debt securities (loans to companies/governments)
- Commodity Market: Physical goods (gold, oil, agricultural products)
- Forex Market: Currency exchange
- Derivatives Market: Contracts based on underlying assets
Market Participants
- Retail Investors: Individual investors
- Institutional Investors: Banks, hedge funds, pension funds
- Market Makers: Provide liquidity
- Brokers: Execute trades on behalf of clients
- Regulators: Ensure fair and orderly markets
Investment Instruments
1. Stocks (Equities)
Ownership shares in a company.
Types:
- Common Stock: Voting rights, dividends
- Preferred Stock: Fixed dividends, priority over common
Metrics:
- Price-to-Earnings (P/E) Ratio: Stock price / Earnings per share
- Dividend Yield: Annual dividend / Stock price
- Market Capitalization: Share price × Shares outstanding
# Calculate basic stock metrics
def calculate_pe_ratio(price, earnings_per_share):
"""Price-to-Earnings Ratio"""
return price / earnings_per_share
def calculate_dividend_yield(annual_dividend, stock_price):
"""Dividend Yield as percentage"""
return (annual_dividend / stock_price) * 100
def calculate_market_cap(price, shares_outstanding):
"""Market Capitalization"""
return price * shares_outstanding
# Example
stock_price = 150.00
eps = 10.00
annual_dividend = 3.00
shares = 1_000_000_000
pe_ratio = calculate_pe_ratio(stock_price, eps)
dividend_yield = calculate_dividend_yield(annual_dividend, stock_price)
market_cap = calculate_market_cap(stock_price, shares)
print(f"P/E Ratio: {pe_ratio:.2f}")
print(f"Dividend Yield: {dividend_yield:.2f}%")
print(f"Market Cap: ${market_cap:,.0f}")
See: Stocks Guide
2. Options
Contracts giving the right (not obligation) to buy/sell at a specific price.
Types:
- Call Option: Right to buy
- Put Option: Right to sell
Key Terms:
- Strike Price: Exercise price
- Premium: Option cost
- Expiration Date: Contract end date
- In-the-Money (ITM): Profitable to exercise
- Out-of-the-Money (OTM): Not profitable to exercise
- At-the-Money (ATM): Strike ≈ Current price
Greeks:
- Delta: Price sensitivity to underlying
- Gamma: Rate of delta change
- Theta: Time decay
- Vega: Volatility sensitivity
- Rho: Interest rate sensitivity
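The moneyness terms above come down to the payoff at expiration. A minimal Python sketch of long call and put payoffs per share (the prices used are illustrative):

```python
def call_payoff(spot, strike, premium):
    """Profit/loss per share of a long call held to expiration."""
    return max(spot - strike, 0) - premium

def put_payoff(spot, strike, premium):
    """Profit/loss per share of a long put held to expiration."""
    return max(strike - spot, 0) - premium

# A call with strike 100 bought for a 5 premium:
print(call_payoff(110, 100, 5))  # ITM: intrinsic value 10 minus premium 5 -> 5
print(call_payoff(95, 100, 5))   # OTM: expires worthless -> -5
```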
See: Options Trading
3. Futures
Obligatory contracts to buy/sell at a future date and price.
Characteristics:
- Standardized contracts
- Exchange-traded
- Margin requirements
- Daily settlement
Uses:
- Hedging risk
- Speculation
- Price discovery
Common Futures:
- Equity index futures (S&P 500, NASDAQ)
- Commodity futures (oil, gold, corn)
- Currency futures
- Interest rate futures
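Because contracts are standardized, profit and loss on a futures position is just the price move times the contract multiplier. A rough sketch (the $50 multiplier matches the E-mini S&P 500 contract; the prices are made up):

```python
def futures_pnl(entry, exit_price, multiplier, contracts=1, side='long'):
    """Profit/loss on a futures position closed at exit_price."""
    move = (exit_price - entry) if side == 'long' else (entry - exit_price)
    return move * multiplier * contracts

# Long one E-mini S&P contract from 4000 to 4010 (10 points x $50):
print(futures_pnl(4000, 4010, 50))  # 500
```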
See: Futures Trading
4. Cryptocurrencies
Digital or virtual currencies using cryptography.
Popular Cryptocurrencies:
- Bitcoin (BTC): First cryptocurrency
- Ethereum (ETH): Smart contract platform
- Altcoins: Alternative cryptocurrencies
Key Concepts:
- Blockchain: Distributed ledger technology
- Mining: Transaction verification process
- Wallet: Storage for private keys
- Exchange: Platform for trading crypto
See: Cryptocurrency Guide
Investment Strategies
Value Investing
Buy undervalued securities based on fundamental analysis.
Key Principles:
- Focus on intrinsic value
- Margin of safety
- Long-term perspective
- Fundamental analysis
Metrics:
- P/E ratio
- Price-to-Book (P/B) ratio
- Debt-to-Equity ratio
- Free cash flow
Growth Investing
Invest in companies with high growth potential.
Characteristics:
- High P/E ratios
- Revenue growth
- Market expansion
- Innovation focus
Dividend Investing
Focus on stocks paying regular dividends.
Benefits:
- Steady income stream
- Lower volatility
- Compound growth
Metrics:
- Dividend yield
- Payout ratio
- Dividend growth rate
Index Investing
Track market indices through index funds/ETFs.
Advantages:
- Diversification
- Low fees
- Passive management
- Market returns
Popular Indices:
- S&P 500
- NASDAQ-100
- Dow Jones Industrial Average
- Russell 2000
Analysis Methods
Fundamental Analysis
Evaluate intrinsic value through financial statements.
Financial Statements:
- Income Statement: Revenue, expenses, profit
- Balance Sheet: Assets, liabilities, equity
- Cash Flow Statement: Operating, investing, financing cash flows
Key Ratios:
# Profitability Ratios
def gross_margin(revenue, cogs):
return ((revenue - cogs) / revenue) * 100
def net_profit_margin(net_income, revenue):
return (net_income / revenue) * 100
def return_on_equity(net_income, shareholders_equity):
return (net_income / shareholders_equity) * 100
# Liquidity Ratios
def current_ratio(current_assets, current_liabilities):
return current_assets / current_liabilities
def quick_ratio(current_assets, inventory, current_liabilities):
return (current_assets - inventory) / current_liabilities
# Leverage Ratios
def debt_to_equity(total_debt, total_equity):
return total_debt / total_equity
def interest_coverage(ebit, interest_expense):
return ebit / interest_expense
# Efficiency Ratios
def asset_turnover(revenue, total_assets):
return revenue / total_assets
def inventory_turnover(cogs, average_inventory):
return cogs / average_inventory
See: Fundamental Analysis
Technical Analysis
Analyze price patterns and trends using charts.
Common Indicators:
- Moving Averages: Simple (SMA), Exponential (EMA)
- RSI: Relative Strength Index (overbought/oversold)
- MACD: Moving Average Convergence Divergence
- Bollinger Bands: Volatility indicator
- Volume: Trading activity
import pandas as pd
import numpy as np
def simple_moving_average(prices, period):
"""Calculate SMA"""
return prices.rolling(window=period).mean()
def exponential_moving_average(prices, period):
"""Calculate EMA"""
return prices.ewm(span=period, adjust=False).mean()
def relative_strength_index(prices, period=14):
"""Calculate RSI"""
delta = prices.diff()
gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
rs = gain / loss
rsi = 100 - (100 / (1 + rs))
return rsi
def bollinger_bands(prices, period=20, std_dev=2):
"""Calculate Bollinger Bands"""
sma = simple_moving_average(prices, period)
std = prices.rolling(window=period).std()
upper_band = sma + (std * std_dev)
lower_band = sma - (std * std_dev)
return upper_band, sma, lower_band
def macd(prices, fast=12, slow=26, signal=9):
"""Calculate MACD"""
ema_fast = exponential_moving_average(prices, fast)
ema_slow = exponential_moving_average(prices, slow)
macd_line = ema_fast - ema_slow
signal_line = macd_line.ewm(span=signal, adjust=False).mean()
histogram = macd_line - signal_line
return macd_line, signal_line, histogram
See: Technical Analysis
Risk Management
Portfolio Diversification
Don't put all eggs in one basket.
Diversification Strategies:
- Across asset classes (stocks, bonds, real estate)
- Across sectors (tech, healthcare, finance)
- Across geographies (domestic, international)
- Across market caps (large, mid, small)
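Why diversification works shows up in the two-asset volatility formula: unless the assets are perfectly correlated, portfolio risk falls below the weighted average of the individual risks. A sketch with illustrative numbers:

```python
import math

def portfolio_volatility(w1, s1, w2, s2, corr):
    """Volatility of a two-asset portfolio (weights w, volatilities s, correlation corr)."""
    variance = (w1 * s1) ** 2 + (w2 * s2) ** 2 + 2 * w1 * w2 * s1 * s2 * corr
    return math.sqrt(variance)

# 50/50 split between a 20%-vol and a 10%-vol asset, correlation 0.3:
vol = portfolio_volatility(0.5, 0.20, 0.5, 0.10, 0.3)
print(f"{vol:.4f}")  # 0.1245, below the 0.15 weighted average
```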
Position Sizing
Determine how much to invest in each position.
def position_size_fixed_dollar(risk_per_trade):
"""Fixed dollar amount per trade"""
return risk_per_trade
def position_size_percentage(account_balance, risk_percentage):
"""Percentage of account balance"""
return account_balance * (risk_percentage / 100)
def position_size_volatility(account_balance, risk_percentage, entry_price, stop_loss):
"""Based on volatility and stop loss"""
risk_per_share = abs(entry_price - stop_loss)
total_risk = account_balance * (risk_percentage / 100)
shares = total_risk / risk_per_share
return int(shares)
# Example
account = 100000
risk_pct = 2 # 2% risk per trade
entry = 150
stop = 145
shares = position_size_volatility(account, risk_pct, entry, stop)
print(f"Buy {shares} shares at ${entry} with stop at ${stop}")
Stop Loss Orders
Automatic sell orders to limit losses.
Types:
- Fixed Stop: Specific price level
- Trailing Stop: Adjusts with price movement
- Percentage Stop: Based on percentage decline
def calculate_stop_loss(entry_price, stop_percentage, position_type='long'):
"""Calculate stop loss price"""
if position_type == 'long':
return entry_price * (1 - stop_percentage / 100)
else: # short
return entry_price * (1 + stop_percentage / 100)
def calculate_take_profit(entry_price, risk_reward_ratio, stop_loss, position_type='long'):
"""Calculate take profit based on risk/reward ratio"""
risk = abs(entry_price - stop_loss)
reward = risk * risk_reward_ratio
if position_type == 'long':
return entry_price + reward
else: # short
return entry_price - reward
# Example: 2:1 risk/reward ratio
entry = 100
stop_pct = 5
rr_ratio = 2
stop = calculate_stop_loss(entry, stop_pct, 'long')
target = calculate_take_profit(entry, rr_ratio, stop, 'long')
print(f"Entry: ${entry}")
print(f"Stop Loss: ${stop:.2f}")
print(f"Take Profit: ${target:.2f}")
print(f"Risk: ${entry - stop:.2f}")
print(f"Reward: ${target - entry:.2f}")
Performance Metrics
Returns
def simple_return(start_value, end_value):
"""Simple return percentage"""
return ((end_value - start_value) / start_value) * 100
def compound_annual_growth_rate(start_value, end_value, years):
"""CAGR"""
return (((end_value / start_value) ** (1 / years)) - 1) * 100
def total_return(initial_investment, final_value, dividends):
"""Total return including dividends"""
return ((final_value + dividends - initial_investment) / initial_investment) * 100
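A quick check of the CAGR formula with illustrative numbers: doubling an investment over 10 years works out to roughly a 7.18% compound annual return.

```python
def compound_annual_growth_rate(start_value, end_value, years):
    """CAGR as a percentage."""
    return (((end_value / start_value) ** (1 / years)) - 1) * 100

cagr = compound_annual_growth_rate(10_000, 20_000, 10)
print(f"CAGR: {cagr:.2f}%")  # CAGR: 7.18%
```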
Risk Metrics
import numpy as np
def volatility(returns):
"""Standard deviation of returns (annualized)"""
return np.std(returns) * np.sqrt(252) # 252 trading days
def sharpe_ratio(returns, risk_free_rate=0.02):
"""Risk-adjusted return"""
excess_returns = returns - risk_free_rate / 252
return np.mean(excess_returns) / np.std(excess_returns) * np.sqrt(252)
def maximum_drawdown(prices):
"""Maximum peak-to-trough decline"""
cummax = np.maximum.accumulate(prices)
drawdown = (prices - cummax) / cummax
return np.min(drawdown) * 100
def beta(asset_returns, market_returns):
"""Measure of volatility relative to market"""
covariance = np.cov(asset_returns, market_returns)[0][1]
market_variance = np.var(market_returns)
return covariance / market_variance
Trading Psychology
Emotional Discipline
Common Pitfalls:
- Fear of Missing Out (FOMO): Chasing rallies
- Loss Aversion: Holding losers too long
- Overconfidence: Taking excessive risk
- Confirmation Bias: Seeking supporting evidence only
- Anchoring: Fixating on specific price points
Best Practices:
- Follow your trading plan
- Keep emotions in check
- Accept losses gracefully
- Don't overtrade
- Take breaks when needed
Trading Plan
Essential components:
- Entry Criteria: When to buy
- Exit Criteria: When to sell (profit and loss)
- Position Sizing: How much to invest
- Risk Management: Stop loss levels
- Record Keeping: Track all trades
Financial Calculations
Time Value of Money
def future_value(present_value, rate, periods):
"""FV = PV × (1 + r)^n"""
return present_value * (1 + rate) ** periods
def present_value(future_value, rate, periods):
"""PV = FV / (1 + r)^n"""
return future_value / (1 + rate) ** periods
def compound_interest(principal, rate, periods, compounds_per_period=1):
"""Compound interest formula"""
return principal * (1 + rate / compounds_per_period) ** (periods * compounds_per_period)
# Example: $10,000 invested for 10 years at 7% annual return
principal = 10000
rate = 0.07
years = 10
fv = future_value(principal, rate, years)
print(f"${principal:,.2f} grows to ${fv:,.2f} in {years} years")
Annuities
def future_value_annuity(payment, rate, periods):
"""FV of regular payments"""
return payment * (((1 + rate) ** periods - 1) / rate)
def present_value_annuity(payment, rate, periods):
"""PV of regular payments"""
return payment * ((1 - (1 + rate) ** -periods) / rate)
def loan_payment(principal, rate, periods):
"""Calculate loan payment"""
return principal * (rate * (1 + rate) ** periods) / ((1 + rate) ** periods - 1)
# Example: Mortgage calculation
loan_amount = 300000
annual_rate = 0.04
years = 30
monthly_rate = annual_rate / 12
months = years * 12
monthly_payment = loan_payment(loan_amount, monthly_rate, months)
total_paid = monthly_payment * months
total_interest = total_paid - loan_amount
print(f"Monthly Payment: ${monthly_payment:,.2f}")
print(f"Total Interest: ${total_interest:,.2f}")
Investment Accounts
Account Types
Taxable Accounts:
- Individual brokerage accounts
- Joint accounts
- Margin accounts
Tax-Advantaged (US):
- 401(k): Employer-sponsored retirement
- IRA: Individual Retirement Account
- Roth IRA: Tax-free growth and withdrawals
- HSA: Health Savings Account
Fees and Costs
- Expense Ratios: Mutual fund/ETF annual fees
- Trading Commissions: Per-trade fees
- Management Fees: Advisory fees
- Tax Implications: Capital gains, dividends
Orders
Order Types
# Common order types
class Order:
"""Order examples"""
@staticmethod
def market_order():
"""Execute at current market price"""
return {
"type": "market",
"execution": "immediate",
"price": "current market price"
}
@staticmethod
def limit_order(limit_price):
"""Execute at specific price or better"""
return {
"type": "limit",
"limit_price": limit_price,
"execution": "when price reaches limit"
}
@staticmethod
def stop_loss_order(stop_price):
"""Sell when price falls to stop level"""
return {
"type": "stop_loss",
"stop_price": stop_price,
"execution": "when price hits stop"
}
@staticmethod
def stop_limit_order(stop_price, limit_price):
"""Combines stop and limit orders"""
return {
"type": "stop_limit",
"stop_price": stop_price,
"limit_price": limit_price,
"execution": "limit order triggered at stop"
}
Order Duration
- Day Order: Expires at end of trading day
- Good Till Canceled (GTC): Active until executed or canceled
- Fill or Kill (FOK): Execute immediately in full or cancel
- Immediate or Cancel (IOC): Execute immediately, cancel remainder
Resources and Tools
Financial Data Sources
- Yahoo Finance
- Bloomberg Terminal
- TradingView
- Alpha Vantage API
- IEX Cloud
Analysis Tools
- Excel/Google Sheets
- Python (pandas, numpy, matplotlib)
- TradingView
- ThinkOrSwim
- MetaTrader
Educational Resources
- Investopedia
- Khan Academy (Finance)
- CFA Institute
- Financial news (WSJ, FT, Bloomberg)
Available Guides
Explore detailed guides for specific topics:
- General Finance - Fundamental concepts and principles
- Stocks - Equity investing and analysis
- Options - Options trading strategies
- Futures - Futures contracts and trading
- Cryptocurrency - Digital assets and blockchain
- Fundamental Analysis - Company valuation
- Technical Analysis - Chart patterns and indicators
Important Disclaimers
- Not Financial Advice: This is educational content only
- Do Your Own Research: Always verify information
- Risk Warning: Investing involves risk of loss
- Past Performance: Does not guarantee future results
- Diversification: Does not ensure profit or protect against loss
- Consult Professionals: Consider seeking professional advice
Key Principles
Investment Principles
- Start Early: Compound interest is powerful
- Diversify: Spread risk across assets
- Invest Regularly: Dollar-cost averaging
- Control Costs: Minimize fees and taxes
- Stay Disciplined: Stick to your plan
- Educate Yourself: Continuous learning
- Manage Risk: Protect your capital
- Think Long-Term: Avoid emotional decisions
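Dollar-cost averaging, mentioned above, can be made concrete: investing a fixed dollar amount each period buys more shares when prices are low, so the average cost per share ends up at or below the average price. A sketch with made-up prices:

```python
def dca_average_cost(prices, amount_per_period):
    """Average cost per share when investing a fixed amount each period."""
    shares = sum(amount_per_period / p for p in prices)
    total_invested = amount_per_period * len(prices)
    return total_invested / shares

prices = [10, 5, 10]                  # hypothetical monthly prices
avg_cost = dca_average_cost(prices, 100)
avg_price = sum(prices) / len(prices)
print(avg_cost, avg_price)            # 7.5 vs ~8.33
```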
Risk Management Rules
- Never risk more than you can afford to lose
- Use stop losses to limit downside
- Diversify across multiple positions
- Size positions appropriately
- Have a clear exit strategy
- Don't let winners become losers
- Cut losses quickly
- Let profits run (within reason)
Next Steps
- Learn the basics of General Finance
- Study Fundamental Analysis
- Explore Technical Analysis
- Understand Stocks and equity markets
- Learn about Options for hedging and income
- Research Cryptocurrency opportunities
- Practice with paper trading before using real money
- Build a diversified portfolio aligned with your goals
- Continuously educate yourself
- Start investing with money you can afford to lose
Remember: Successful investing requires knowledge, discipline, and patience. Take time to learn, practice with small amounts, and gradually build your skills and portfolio.
General
Sharpe Ratio
The Sharpe Ratio is a widely used metric in finance to evaluate the performance of an investment by measuring the excess return per unit of risk. It is calculated by dividing the difference between the return of the investment and the risk-free rate by the standard deviation of the investment's returns.
$$ SR = \frac{R_p - R_f}{\sigma_p} $$
Where:
- \( R_p \) is the return of the portfolio
- \( R_f \) is the risk-free rate (typically the yield on short-term government debt, such as Treasury bills)
- \( \sigma_p \) is the standard deviation of the portfolio's returns
Calculating Standard Deviation of Returns
The standard deviation of returns is a measure of the dispersion or variability of investment returns over a period of time. It helps in understanding the risk associated with the investment. Here is a step-by-step process to calculate the standard deviation of returns:
- Collect the Returns Data: Gather the periodic returns of the investment. These returns can be daily, monthly, or yearly.
- Calculate the Mean Return: Compute the average return over the period.
$$ \bar{R} = \frac{\sum_{i=1}^{n} R_i}{n} $$
Where:
- \( \bar{R} \) is the mean return
- \( R_i \) is the return for period \( i \)
- \( n \) is the number of periods
- Compute the Variance: Calculate the variance by finding the average of the squared differences between each return and the mean return.
$$ \sigma^2 = \frac{\sum_{i=1}^{n} (R_i - \bar{R})^2}{n} $$
Where:
- \( \sigma^2 \) is the variance
- Calculate the Standard Deviation: Take the square root of the variance to get the standard deviation.
$$ \sigma = \sqrt{\sigma^2} $$
Where:
- \( \sigma \) is the standard deviation
Sample Calculation
Assume the following monthly returns for an investment over 5 months: 2%, 3%, -1%, 4%, and 5%.
- Mean Return:
$$ \bar{R} = \frac{2 + 3 - 1 + 4 + 5}{5} = \frac{13}{5} = 2.6\% $$
- Variance:
$$ \sigma^2 = \frac{(2 - 2.6)^2 + (3 - 2.6)^2 + (-1 - 2.6)^2 + (4 - 2.6)^2 + (5 - 2.6)^2}{5} $$
$$ \sigma^2 = \frac{(-0.6)^2 + (0.4)^2 + (-3.6)^2 + (1.4)^2 + (2.4)^2}{5} $$
$$ \sigma^2 = \frac{0.36 + 0.16 + 12.96 + 1.96 + 5.76}{5} = \frac{21.2}{5} = 4.24 $$
- Standard Deviation:
$$ \sigma = \sqrt{4.24} \approx 2.06\% $$
In this example, the standard deviation of the returns is approximately 2.06%, indicating the variability of the investment returns over the period.
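The steps above can be reproduced in a few lines of Python (population standard deviation, dividing by \( n \) as in the worked example):

```python
# Monthly returns in percent, matching the worked example above.
returns = [2, 3, -1, 4, 5]

n = len(returns)
mean = sum(returns) / n                               # 13 / 5 = 2.6
variance = sum((r - mean) ** 2 for r in returns) / n  # 21.2 / 5 = 4.24
std_dev = variance ** 0.5                             # sqrt(4.24) ~= 2.06

print(round(mean, 2), round(variance, 2), round(std_dev, 2))
```

Note that this is the population form; a sample standard deviation would divide by \( n - 1 \) instead.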
Sample Scenario
To better understand the Sharpe Ratio, let's consider a practical example.
Assume the following data for a portfolio:
- Portfolio return \( R_p \): 12% or 0.12
- Risk-free rate \( R_f \): 2% or 0.02
- Portfolio standard deviation \( \sigma_p \): 8% or 0.08
Using the Sharpe Ratio formula:
$$ SR = \frac{R_p - R_f}{\sigma_p} $$
Substituting the values:
$$ SR = \frac{0.12 - 0.02}{0.08} = \frac{0.10}{0.08} = 1.25 $$
In this scenario, the Sharpe Ratio is 1.25, indicating that the portfolio generates 1.25 units of excess return for each unit of risk taken.
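A minimal Python sketch of the calculation above:

```python
def sharpe_ratio(portfolio_return, risk_free_rate, std_dev):
    """Excess return per unit of risk: (R_p - R_f) / sigma_p."""
    return (portfolio_return - risk_free_rate) / std_dev

# Values from the scenario above.
sr = sharpe_ratio(0.12, 0.02, 0.08)
print(round(sr, 2))  # 1.25
```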
Kelly Criterion
The Kelly Criterion is a formula used to determine the optimal size of a series of bets. It sizes each bet in proportion to your edge over the odds, maximizing the long-term growth of capital. For a bet with net odds \( b \) (you win \( b \) units for every unit staked), the optimal fraction of your bankroll is:
$$ k = \frac{bp - q}{b} $$
Where:
- \( p \) is the probability of winning
- \( q = 1 - p \) is the probability of losing
- \( b \) is the net odds of the bet
Sample Scenario
Consider a scenario to illustrate the Kelly Criterion.
Assume the following data for a bet:
- Probability of winning \( p \): 60% or 0.60
- Probability of losing \( q \): 40% or 0.40
- Net odds of the bet \( b \): 2:1 (win 2 units for every 1 staked)
Understanding Odds of a Bet
The odds of a bet represent the ratio of the probability of winning to the probability of losing. They are a crucial component in betting strategies, including the Kelly Criterion. Odds can be expressed in different formats, such as fractional, decimal, and moneyline.
Fractional Odds
Fractional odds are commonly used in the UK and are represented as a fraction (e.g., 2/1). The numerator (first number) represents the potential profit, while the denominator (second number) represents the stake. For example, 2/1 odds mean you win $2 for every $1 bet.
Decimal Odds
Decimal odds are popular in Europe and Australia. They are represented as a decimal number (e.g., 3.00). The decimal number includes the original stake, so the total payout is calculated by multiplying the stake by the decimal odds. For example, 3.00 odds mean a $1 bet returns $3 (including the $1 stake).
Moneyline Odds
Moneyline odds are commonly used in the United States and can be positive or negative. Positive moneyline odds (e.g., +200) indicate how much profit you make on a $100 bet. Negative moneyline odds (e.g., -150) indicate how much you need to bet to win $100.
Calculating Odds
To calculate the odds of a bet, you need to know the probabilities of winning and losing. The formula for calculating fractional odds is:
$$ \text{Odds} = \frac{p}{q} $$
Where:
- \( p \) is the probability of winning
- \( q \) is the probability of losing
For example, if the probability of winning is 60% (0.60) and the probability of losing is 40% (0.40), the fractional odds are:
$$ \text{Odds} = \frac{0.60}{0.40} = \frac{3}{2} = 1.5 $$
To convert fractional odds to decimal odds, add 1 to the fractional odds:
$$ \text{Decimal Odds} = \text{Fractional Odds} + 1 $$
Using the previous example:
$$ \text{Decimal Odds} = 1.5 + 1 = 2.5 $$
To convert fractional odds to moneyline odds:
- If the fractional odds are greater than 1 (e.g., 2/1), the moneyline odds are positive: \( \text{Moneyline Odds} = \text{Fractional Odds} \times 100 \)
- If the fractional odds are less than 1 (e.g., 1/2), the moneyline odds are negative: \( \text{Moneyline Odds} = -\left(\frac{100}{\text{Fractional Odds}}\right) \)
Using the previous example (1.5 fractional odds):
$$ \text{Moneyline Odds} = 1.5 \times 100 = +150 $$
Understanding and calculating the odds of a bet is essential for making informed betting decisions and optimizing strategies like the Kelly Criterion.
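The conversions above can be sketched in Python (function names are illustrative):

```python
def fractional_odds(p_win, p_lose):
    """Fractional odds as the ratio of win probability to lose probability."""
    return p_win / p_lose

def to_decimal(fractional):
    """Decimal odds include the stake: fractional odds + 1."""
    return fractional + 1

def to_moneyline(fractional):
    """Positive when fractional odds >= 1, negative otherwise."""
    if fractional >= 1:
        return fractional * 100
    return -(100 / fractional)

odds = fractional_odds(0.60, 0.40)  # 3/2 = 1.5
print(round(odds, 2), round(to_decimal(odds), 2), round(to_moneyline(odds)))
```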
Using the Kelly Criterion formula:
$$ k = \frac{bp - q}{b} $$
Substituting the values:
$$ k = \frac{2 \times 0.60 - 0.40}{2} = \frac{0.80}{2} = 0.40 $$
In this scenario, the Kelly Criterion suggests betting 40% of your bankroll. For example, with a $1000 bankroll, you should bet $400.
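A short Python sketch of the standard Kelly fraction, f* = (b·p − q) / b, applied to this scenario:

```python
def kelly_fraction(p_win, net_odds):
    """Standard Kelly criterion: f* = (b*p - q) / b, with q = 1 - p."""
    q = 1 - p_win
    return (net_odds * p_win - q) / net_odds

# 60% chance of winning at net odds of 2:1.
f = kelly_fraction(0.60, 2)
bankroll = 1000
print(round(f, 2), round(bankroll * f, 2))
```

A negative result would mean there is no edge and the bet should be skipped entirely.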
Intuition of the Kelly Criterion
The Kelly Criterion is a mathematical formula used to determine the optimal size of a series of bets to maximize the logarithm of wealth over time. It is particularly useful in scenarios where the goal is to grow wealth exponentially while managing risk. The intuition behind the Kelly Criterion can be broken down into several key concepts:
Key Concepts
- Maximizing Growth: The Kelly Criterion aims to maximize the long-term growth rate of your bankroll. By betting a fraction of your bankroll that is proportional to the edge you have over the odds, you can achieve exponential growth over time.
- Balancing Risk and Reward: The formula balances the potential reward of a bet with the risk of losing. By betting too much, you risk significant losses that can deplete your bankroll. By betting too little, you miss out on potential gains. The Kelly Criterion finds the optimal balance.
- Proportional Betting: The Kelly Criterion suggests betting a fraction of your bankroll that is proportional to your edge. This means that as your edge increases, the fraction of your bankroll you should bet also increases. Conversely, if your edge decreases, you should bet a smaller fraction.
- Logarithmic Utility: The Kelly Criterion is based on the concept of logarithmic utility, which means that the utility (or satisfaction) derived from wealth increases logarithmically. This approach ensures that the strategy is focused on long-term growth rather than short-term gains.
Example Scenario
Consider a scenario where you have a 60% chance of winning a bet (probability \( p = 0.60 \)) and a 40% chance of losing (probability \( q = 0.40 \)). The net odds offered are 2:1 (decimal odds of 3.0).
Using the Kelly Criterion formula:
$$ k = \frac{bp - q}{b} $$
Substituting the values:
$$ k = \frac{2 \times 0.60 - 0.40}{2} = \frac{0.80}{2} = 0.40 $$
In this scenario, the Kelly Criterion suggests betting 40% of your bankroll. For example, with a $1000 bankroll, you should bet $400.
Advantages of the Kelly Criterion
- Optimal Growth: The Kelly Criterion maximizes the long-term growth rate of your bankroll, ensuring that you achieve exponential growth over time.
- Risk Management: By betting a fraction of your bankroll, the Kelly Criterion helps manage risk and prevent significant losses.
- Adaptability: The formula adjusts the bet size based on the edge, allowing for flexible and adaptive betting strategies.
Conclusion
The Kelly Criterion is a powerful tool for optimizing bet sizes and maximizing long-term growth. By balancing risk and reward and focusing on proportional betting, the Kelly Criterion provides a strategic approach to betting that can lead to exponential wealth growth over time. Understanding the intuition behind the Kelly Criterion can help you make more informed and strategic betting decisions.
Pot Geometry
Pot Geometry is a strategic betting approach where a consistent fraction of the pot is wagered on each round. Also known as geometric bet sizing, this strategy aims to maximize the amount of money an opponent contributes to the pot.
Detailed Explanation of Pot Geometry
Pot Geometry is particularly useful in games like poker, where managing the pot size and betting strategically can significantly impact outcomes. By betting a fixed fraction of the pot on each round, the pot size grows exponentially, maximizing potential winnings while managing risk.
Key Concepts
- Fractional Betting: A fixed fraction of the current pot size is bet on each round. For instance, if the fraction is 50%, then 50% of the current pot size is added to the pot each round.
- Exponential Growth: Consistent fractional betting leads to exponential growth of the pot size, potentially increasing winnings over multiple rounds.
- Risk Management: Pot Geometry ensures bets are proportional to the current pot size, preventing over-betting and large losses.
Example Scenario
Consider an example to demonstrate Pot Geometry:
- Initial pot size: $100
- Fraction of pot to bet: 50% (0.50)
Round 1:
- Current pot size: $100
- Bet size: 50% of $100 = $50
- New pot size: $100 + $50 = $150
Round 2:
- Current pot size: $150
- Bet size: 50% of $150 = $75
- New pot size: $150 + $75 = $225
Round 3:
- Current pot size: $225
- Bet size: 50% of $225 = $112.50
- New pot size: $225 + $112.50 = $337.50
As shown, the pot size grows exponentially with each betting round.
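The three rounds above can be reproduced with a short loop:

```python
# Pot growth when a fixed fraction of the pot is bet each round,
# matching the 50% example above.
pot = 100.0
fraction = 0.50

for round_no in range(1, 4):
    bet = pot * fraction
    pot += bet
    print(f"Round {round_no}: bet {bet:.2f}, new pot {pot:.2f}")

# After n rounds the pot is initial * (1 + fraction) ** n,
# e.g. 100 * 1.5 ** 3 = 337.50.
```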
Advantages of Pot Geometry
- Consistent Growth: The pot grows steadily, allowing for potentially higher winnings over multiple rounds.
- Controlled Risk: Betting a fraction of the pot controls risk, keeping it proportional to the current pot size.
- Strategic Flexibility: Players can adjust the betting fraction based on confidence and game dynamics.
Conclusion
Pot Geometry is a powerful betting strategy that leverages exponential growth and risk management principles. By consistently betting a fraction of the pot, players can maximize potential winnings while maintaining controlled risk. This strategy is particularly effective in poker, where strategic pot management can significantly influence long-term success.
Technical Analysis
Stock Technical Analysis
Stock technical analysis is a method used to evaluate and predict the future price movements of stocks by analyzing historical price data, trading volume, and other market indicators. Unlike fundamental analysis, which focuses on a company's financial health and intrinsic value, technical analysis relies on chart patterns, technical indicators, and statistical measures to make trading decisions.
Key Concepts
- Price Trends: Technical analysts study price trends to identify the direction in which a stock's price is moving. Trends can be upward (bullish), downward (bearish), or sideways (neutral). Recognizing trends helps traders make informed decisions about when to buy or sell stocks.
- Support and Resistance Levels: Support levels are price points where a stock tends to find buying interest, preventing it from falling further. Resistance levels are price points where selling interest is strong enough to prevent the stock from rising further. Identifying these levels helps traders set entry and exit points.
- Chart Patterns: Chart patterns are visual formations created by the price movements of a stock. Common patterns include head and shoulders, double tops and bottoms, triangles, and flags. These patterns can signal potential reversals or continuations in price trends.
- Technical Indicators: Technical indicators are mathematical calculations based on price, volume, or open interest data. Popular indicators include moving averages, relative strength index (RSI), moving average convergence divergence (MACD), and Bollinger Bands. These indicators help traders identify overbought or oversold conditions, trend strength, and potential reversal points.
- Volume Analysis: Trading volume is the number of shares traded during a specific period. Analyzing volume helps confirm the strength of price movements. For example, a price increase accompanied by high volume suggests strong buying interest, while a price increase with low volume may indicate weak buying interest.
Example Scenario
Consider a stock that has been in an upward trend for several months. A technical analyst might use the following steps to evaluate the stock:
- Identify the Trend: The analyst observes that the stock is in a bullish trend, with higher highs and higher lows on the price chart.
- Determine Support and Resistance Levels: The analyst identifies key support levels at $50 and $55, and resistance levels at $65 and $70.
- Analyze Chart Patterns: The analyst notices a bullish flag pattern forming, indicating a potential continuation of the upward trend.
- Use Technical Indicators: The analyst checks the RSI, which shows the stock is not yet overbought, and the MACD, which indicates strong bullish momentum.
- Examine Volume: The analyst observes that recent price increases are accompanied by high trading volume, confirming strong buying interest.
Based on this analysis, the technical analyst might decide to buy the stock, anticipating further price increases.
Advantages of Technical Analysis
- Timely Decision-Making: Technical analysis provides real-time data and signals, allowing traders to make quick and informed decisions.
- Market Sentiment Insight: By analyzing price and volume data, technical analysis helps traders gauge market sentiment and investor behavior.
- Versatility: Technical analysis can be applied to various financial instruments, including stocks, options, futures, and cryptocurrencies.
Conclusion
Stock technical analysis is a valuable tool for traders and investors seeking to predict future price movements and make informed trading decisions. By understanding key concepts such as price trends, support and resistance levels, chart patterns, technical indicators, and volume analysis, traders can develop effective strategies to navigate the stock market. While technical analysis has its limitations, it remains a popular and widely used method for analyzing and trading stocks.
Moving Averages
Moving averages are one of the most commonly used technical indicators in stock analysis. They smooth out price data to identify the direction of the trend over a specific period. There are two main types of moving averages:
- Simple Moving Average (SMA): The SMA is calculated by taking the average of a stock's price over a specific number of periods. For example, a 10-day SMA is the average of the closing prices of the last 10 days.
- Exponential Moving Average (EMA): The EMA gives more weight to recent prices, making it more responsive to new information. It is calculated using a formula that applies a weighting factor to the most recent price data.
Example Scenario
Consider a stock with the following closing prices over 5 days: $10, $12, $14, $16, and $18.
- The 5-day SMA would be: (10 + 12 + 14 + 16 + 18) / 5 = $14.
- The 5-day EMA would place more weight on the recent prices, resulting in a value closer to the latest price of $18.
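A plain-Python sketch of both averages; the EMA here uses the common smoothing factor 2 / (window + 1) and is seeded with the first price, one of several conventions:

```python
def sma(prices, window):
    """Simple moving average over the last `window` prices."""
    return sum(prices[-window:]) / window

def ema(prices, window):
    """Exponential moving average: recent prices get more weight."""
    k = 2 / (window + 1)  # smoothing factor
    value = prices[0]     # seed with the first price
    for price in prices[1:]:
        value = price * k + value * (1 - k)
    return value

prices = [10, 12, 14, 16, 18]
print(sma(prices, 5), round(ema(prices, 5), 2))  # EMA sits above the SMA of 14
```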
Advantages of Moving Averages
- Trend Identification: Moving averages help identify the direction of the trend, making it easier for traders to follow the market's momentum.
- Support and Resistance Levels: Moving averages can act as dynamic support and resistance levels, providing entry and exit points for trades.
Relative Strength Index (RSI)
The Relative Strength Index (RSI) is a momentum oscillator that measures the speed and change of price movements. It ranges from 0 to 100 and is used to identify overbought or oversold conditions in a stock.
- Overbought: An RSI above 70 suggests that a stock may be overbought and due for a correction.
- Oversold: An RSI below 30 indicates that a stock may be oversold and could be due for a rebound.
Example Scenario
Consider a stock with an RSI of 75. This high RSI value suggests that the stock is overbought, and a trader might consider selling or shorting the stock in anticipation of a price correction.
Advantages of RSI
- Momentum Measurement: RSI helps measure the strength of a stock's price movement, providing insights into potential reversals.
- Overbought/Oversold Signals: RSI provides clear signals for overbought and oversold conditions, aiding in decision-making.
Moving Average Convergence Divergence (MACD)
The Moving Average Convergence Divergence (MACD) is a trend-following momentum indicator that shows the relationship between two moving averages of a stock's price. It consists of three components:
- MACD Line: The difference between the 12-day EMA and the 26-day EMA.
- Signal Line: A 9-day EMA of the MACD line.
- Histogram: The difference between the MACD line and the signal line.
Example Scenario
Consider a stock where the MACD line crosses above the signal line. This bullish crossover indicates a potential buy signal, suggesting that the stock's price may rise.
Advantages of MACD
- Trend and Momentum: MACD combines trend and momentum analysis, providing a comprehensive view of the stock's price action.
- Crossover Signals: MACD crossovers generate buy and sell signals, aiding in timing trades.
Bollinger Bands
Bollinger Bands are a volatility indicator that consists of three lines: the middle band (SMA), the upper band, and the lower band. The upper and lower bands are typically set two standard deviations away from the middle band.
- Upper Band: Indicates overbought conditions when the price touches or exceeds it.
- Lower Band: Indicates oversold conditions when the price touches or falls below it.
Example Scenario
Consider a stock trading near the upper Bollinger Band. This suggests that the stock may be overbought, and a trader might consider selling or shorting the stock.
Advantages of Bollinger Bands
- Volatility Measurement: Bollinger Bands adjust to market volatility, providing dynamic support and resistance levels.
- Overbought/Oversold Conditions: Bollinger Bands help identify overbought and oversold conditions, aiding in decision-making.
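A minimal sketch of the band calculation over one window of prices (the prices are hypothetical):

```python
def bollinger_bands(prices, num_std=2):
    """Middle band = mean of the window; upper/lower = +/- num_std std deviations."""
    n = len(prices)
    middle = sum(prices) / n
    variance = sum((p - middle) ** 2 for p in prices) / n
    std = variance ** 0.5
    return middle - num_std * std, middle, middle + num_std * std

# Hypothetical 5-day closing prices.
lower, middle, upper = bollinger_bands([20, 21, 19, 22, 23])
print(round(lower, 2), round(middle, 2), round(upper, 2))
```

In practice the middle band is a rolling SMA (often 20 periods); this sketch computes a single window for clarity.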
By understanding and utilizing these technical indicators—moving averages, RSI, MACD, and Bollinger Bands—traders can develop more informed and effective trading strategies to navigate the stock market.
Chart Patterns
Chart patterns are formations created by the price movements of a stock or other financial instrument on a chart. These patterns are used by technical analysts to predict future price movements based on historical data. Chart patterns can be classified into two main categories: continuation patterns and reversal patterns.
Continuation Patterns
Continuation patterns indicate that the current trend is likely to continue after the pattern is completed. Some common continuation patterns include:
- Triangles: Triangles are formed by converging trendlines that represent a period of consolidation before the price breaks out in the direction of the existing trend. There are three types of triangles:
  - Ascending Triangle: Characterized by a flat upper trendline and a rising lower trendline, indicating a potential bullish breakout.
  - Descending Triangle: Characterized by a flat lower trendline and a descending upper trendline, indicating a potential bearish breakout.
  - Symmetrical Triangle: Formed by converging upper and lower trendlines, indicating a potential breakout in either direction.
- Flags and Pennants: Flags and pennants are short-term continuation patterns that represent brief periods of consolidation before the price resumes its previous trend.
  - Flag: A rectangular pattern that slopes against the prevailing trend, indicating a brief consolidation before the trend continues.
  - Pennant: A small symmetrical triangle that forms after a strong price movement, indicating a brief consolidation before the trend continues.
- Rectangles: Rectangles are formed by horizontal support and resistance levels, indicating a period of consolidation before the price breaks out in the direction of the existing trend.
Reversal Patterns
Reversal patterns indicate that the current trend is likely to reverse after the pattern is completed. Some common reversal patterns include:
- Head and Shoulders: The head and shoulders pattern is a bearish reversal pattern that consists of three peaks: a higher peak (head) between two lower peaks (shoulders). The pattern is confirmed when the price breaks below the neckline, indicating a potential trend reversal.
- Inverse Head and Shoulders: The inverse head and shoulders pattern is a bullish reversal pattern that consists of three troughs: a lower trough (head) between two higher troughs (shoulders). The pattern is confirmed when the price breaks above the neckline, indicating a potential trend reversal.
- Double Top and Double Bottom: The double top is a bearish reversal pattern that consists of two peaks at approximately the same price level, indicating a potential trend reversal when the price breaks below the support level. The double bottom is a bullish reversal pattern that consists of two troughs at approximately the same price level, indicating a potential trend reversal when the price breaks above the resistance level.
- Triple Top and Triple Bottom: The triple top is a bearish reversal pattern that consists of three peaks at approximately the same price level, indicating a potential trend reversal when the price breaks below the support level. The triple bottom is a bullish reversal pattern that consists of three troughs at approximately the same price level, indicating a potential trend reversal when the price breaks above the resistance level.
Example Scenario
Consider a stock that forms an ascending triangle pattern. The stock's price has been rising, and the pattern is characterized by a flat upper trendline and a rising lower trendline. This suggests that the stock is likely to break out to the upside, continuing its upward trend.
Advantages of Chart Patterns
- Predictive Power: Chart patterns provide insights into potential future price movements based on historical data.
- Visual Representation: Chart patterns offer a visual representation of market psychology and investor behavior.
- Versatility: Chart patterns can be applied to various financial instruments and timeframes, making them a versatile tool for technical analysis.
By understanding and utilizing chart patterns, traders can enhance their ability to predict future price movements and make more informed trading decisions. Combining chart patterns with other technical indicators can further improve the accuracy of trading strategies.
How to Find Support Levels
Support levels are price levels at which a stock or other financial instrument tends to find buying interest, preventing the price from falling further. Identifying support levels is crucial for traders as it helps them make informed decisions about entry and exit points. Here are some methods to find support levels:
Methods to Identify Support Levels
- Historical Price Levels: Look for price levels where the stock has previously found support. These levels can be identified by examining past price charts and noting where the price has repeatedly bounced back up.
- Moving Averages: Moving averages, such as the 50-day or 200-day moving average, can act as dynamic support levels. When the price approaches these moving averages, it often finds support and reverses direction.
- Trendlines: Draw trendlines by connecting the lows of an uptrend. These trendlines can act as support levels, indicating where the price is likely to find buying interest.
- Fibonacci Retracement Levels: Use Fibonacci retracement levels to identify potential support levels. Common retracement levels include 38.2%, 50%, and 61.8%. These levels are based on the Fibonacci sequence and can indicate where the price may find support during a pullback.
- Volume Profile: Analyze the volume profile to identify price levels with high trading activity. These levels often act as support, as they represent areas where a significant number of buyers have previously entered the market.
- Psychological Levels: Round numbers, such as $50, $100, or $1000, often act as psychological support levels. Traders tend to place buy orders at these levels, creating support.
Example Scenario
Consider a stock that has been in an uptrend and is currently trading at $150. By examining the historical price chart, you notice that the stock has previously found support at $140. Additionally, the 50-day moving average is currently at $140, reinforcing this level as a potential support. You also draw a trendline connecting the recent lows, which intersects at $140. Based on this analysis, you identify $140 as a strong support level for the stock.
Advantages of Identifying Support Levels
- Informed Decision-Making: Knowing support levels helps traders make informed decisions about when to enter or exit a trade.
- Risk Management: Identifying support levels allows traders to set stop-loss orders below these levels, managing risk and minimizing potential losses.
- Improved Timing: By recognizing support levels, traders can improve their timing for entering trades, increasing the likelihood of profitable outcomes.
By understanding and utilizing support levels, traders can enhance their ability to predict price movements and make more strategic trading decisions. Combining support levels with other technical indicators and chart patterns can further improve the accuracy of trading strategies.
Best Brokers for Futures Trading
Choosing the right broker is crucial for successful futures trading. Here are some of the best brokers for futures trading, known for their robust platforms, competitive fees, and excellent customer support:
- TD Ameritrade: TD Ameritrade offers a powerful trading platform called thinkorswim, which is highly regarded for its advanced charting tools, technical analysis features, and real-time data. They provide competitive commission rates and a wide range of futures products.
- Interactive Brokers: Interactive Brokers is known for its low-cost trading and extensive range of futures contracts. Their Trader Workstation (TWS) platform is highly customizable and offers advanced trading tools, including algorithmic trading and risk management features.
- E\*TRADE: E\*TRADE provides a user-friendly platform with comprehensive research tools and educational resources. Their Power E\*TRADE platform is designed for active traders and offers advanced charting, technical analysis, and real-time data.
- Charles Schwab: Charles Schwab offers a robust trading platform with a wide range of futures products. Their StreetSmart Edge platform provides advanced charting tools, technical analysis, and real-time data. Schwab is also known for its excellent customer service and educational resources.
- NinjaTrader: NinjaTrader is a popular choice among futures traders for its advanced charting and analysis tools. The platform offers a wide range of technical indicators, automated trading capabilities, and competitive commission rates. NinjaTrader also provides access to a large community of traders and educational resources.
- TradeStation: TradeStation is known for its powerful trading platform and advanced analytical tools. They offer a wide range of futures products and competitive commission rates. TradeStation's platform is highly customizable and provides access to real-time data, advanced charting, and technical analysis.
Example Scenario
Consider a trader who wants to trade crude oil futures. They choose Interactive Brokers for its low-cost trading and extensive range of futures contracts. The trader uses the Trader Workstation (TWS) platform to analyze crude oil price charts and identify trading opportunities. They place a market order to buy one crude oil futures contract and set stop-loss and take-profit levels based on their analysis. The trader continuously monitors the market and adjusts their orders as needed, ultimately achieving a profitable trade.
Advantages of Choosing the Right Broker
- Advanced Trading Tools: The best brokers offer advanced trading platforms with powerful charting, technical analysis, and real-time data.
- Competitive Fees: Low commission rates and competitive fees can significantly impact overall trading profitability.
- Customer Support: Excellent customer support ensures that traders can get help when needed, improving their trading experience.
- Educational Resources: Access to educational resources and research tools can help traders improve their skills and make more informed decisions.
By choosing the right broker, traders can enhance their futures trading experience and increase their chances of success. It's important to consider factors such as trading platform features, commission rates, customer support, and educational resources when selecting a broker for futures trading.
Fundamental Analysis
What is Fundamental Analysis?
Fundamental analysis is a method of evaluating the intrinsic value of an asset, such as a stock, by examining related economic, financial, and other qualitative and quantitative factors. The goal of fundamental analysis is to determine whether an asset is overvalued or undervalued by the market, and to make investment decisions based on this assessment.
Key Components of Fundamental Analysis
- Economic Analysis: This involves analyzing the overall economic environment, including factors such as GDP growth, inflation rates, interest rates, and employment levels. Economic conditions can have a significant impact on the performance of individual companies and industries.
- Industry Analysis: This involves examining the specific industry in which a company operates. Factors to consider include industry growth rates, competitive dynamics, regulatory environment, and technological advancements. Understanding the industry context helps in assessing a company's potential for growth and profitability.
- Company Analysis: This involves a detailed examination of a company's financial statements, management team, business model, and competitive position. Key financial metrics to analyze include revenue, earnings, profit margins, return on equity, and debt levels. Qualitative factors such as management quality, corporate governance, and brand strength are also important.
Financial Statements
Fundamental analysis relies heavily on the analysis of financial statements, which provide a comprehensive view of a company's financial health. The three main financial statements are:
- Income Statement: This statement provides information about a company's revenues, expenses, and profits over a specific period. Key metrics to analyze include gross profit, operating income, and net income.
- Balance Sheet: This statement provides a snapshot of a company's assets, liabilities, and shareholders' equity at a specific point in time. Key metrics to analyze include current assets, current liabilities, long-term debt, and equity.
- Cash Flow Statement: This statement provides information about a company's cash inflows and outflows over a specific period. Key metrics to analyze include operating cash flow, investing cash flow, and financing cash flow.
Valuation Methods
Fundamental analysis involves various valuation methods to estimate the intrinsic value of an asset. Some common valuation methods include:
- Discounted Cash Flow (DCF) Analysis: This method involves estimating the present value of a company's future cash flows. The DCF analysis requires making assumptions about future revenue growth, profit margins, and discount rates.
- Price-to-Earnings (P/E) Ratio: This ratio compares a company's current stock price to its earnings per share (EPS). A high P/E ratio may indicate that a stock is overvalued, while a low P/E ratio may indicate that it is undervalued.
- Price-to-Book (P/B) Ratio: This ratio compares a company's current stock price to its book value per share. The book value is the value of a company's assets minus its liabilities. A low P/B ratio may indicate that a stock is undervalued.
- Dividend Discount Model (DDM): This method involves estimating the present value of a company's future dividend payments. The DDM is particularly useful for valuing companies with a stable dividend payout history.
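The DCF idea can be sketched in a few lines of Python; the cash flows and discount rate below are hypothetical:

```python
def dcf_value(cash_flows, discount_rate):
    """Present value of a series of future annual cash flows."""
    return sum(cf / (1 + discount_rate) ** t
               for t, cf in enumerate(cash_flows, start=1))

# Hypothetical: $100 of free cash flow per year for 5 years at a 10% discount rate.
value = dcf_value([100] * 5, 0.10)
print(round(value, 2))  # well below the undiscounted $500
```

A full DCF would also add a terminal value for cash flows beyond the forecast horizon; this sketch covers only the explicit forecast period.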
Conclusion
Fundamental analysis is a comprehensive approach to evaluating the intrinsic value of an asset by examining economic, industry, and company-specific factors. By analyzing financial statements and using various valuation methods, investors can make informed decisions about whether to buy, hold, or sell an asset. While fundamental analysis requires a thorough understanding of financial concepts and data, it provides valuable insights into the true worth of an investment.
Stocks
Key Financial Ratios
Price to Earnings Ratio (P/E)
The Price to Earnings Ratio (P/E) is a fundamental valuation tool that compares a company's current share price to its earnings per share (EPS). It is expressed as:
$$ P/E = \frac{Price\ per\ Share}{Earnings\ per\ Share} $$
A high P/E ratio might suggest that a stock is overvalued or that investors anticipate significant growth. Conversely, a low P/E ratio could indicate undervaluation or potential challenges faced by the company.
Price to Book Ratio (P/B)
The Price to Book Ratio (P/B) evaluates a company's market value against its book value. It is determined by:
$$ P/B = \frac{Market\ Price\ per\ Share}{Book\ Value\ per\ Share} $$
The book value represents the net asset value, calculated as total assets minus total liabilities (analysts sometimes also exclude intangible assets to get tangible book value). A lower P/B ratio may signal undervaluation, while a higher ratio could imply overvaluation.
Debt to Equity Ratio (D/E)
The Debt to Equity Ratio (D/E) assesses a company's financial leverage by comparing its total liabilities to shareholder equity. It is calculated as:
$$ D/E = \frac{Total\ Liabilities}{Shareholder\ Equity} $$
A higher D/E ratio indicates greater reliance on debt for financing, which can be risky if not managed well. A lower ratio suggests a more conservative financial strategy.
Return on Equity (ROE)
Return on Equity (ROE) measures a company's profitability by comparing net income to shareholder equity. It is expressed as:
$$ ROE = \frac{Net\ Income}{Shareholder\ Equity} $$
A higher ROE signifies effective profit generation from equity investments, serving as a crucial indicator of financial performance and efficiency.
Current Ratio
The Current Ratio is a liquidity metric that evaluates a company's ability to meet short-term obligations with its current assets. It is calculated as:
$$ Current\ Ratio = \frac{Current\ Assets}{Current\ Liabilities} $$
A higher current ratio suggests strong short-term financial health, while a lower ratio may indicate potential liquidity challenges.
Quick Ratio
The Quick Ratio, or acid-test ratio, is a stringent liquidity measure that excludes inventory from current assets. It is calculated as:
$$ Quick\ Ratio = \frac{Current\ Assets - Inventory}{Current\ Liabilities} $$
A higher quick ratio indicates the ability to meet short-term obligations without relying on inventory sales.
Dividend Yield
The Dividend Yield reflects the annual dividend income relative to the market price per share. It is calculated as:
$$ Dividend\ Yield = \frac{Annual\ Dividends\ per\ Share}{Price\ per\ Share} $$
A higher dividend yield suggests a company is returning more income to shareholders, appealing to income-focused investors.
Earnings Per Share (EPS)
Earnings Per Share (EPS) is a critical profitability metric indicating the profit generated per share of stock. It is calculated as:
$$ EPS = \frac{Net\ Income - Dividends\ on\ Preferred\ Stock}{Average\ Outstanding\ Shares} $$
A higher EPS reflects better profitability and is a key factor for investors assessing financial health.
Price to Sales Ratio (P/S)
The Price to Sales Ratio (P/S) compares a company's market capitalization to its total sales or revenue. It is expressed as:
$$ P/S = \frac{Market\ Capitalization}{Total\ Sales} $$
A lower P/S ratio may indicate undervaluation, while a higher ratio could suggest overvaluation, especially useful for companies with minimal earnings.
Conclusion
Analyzing these financial ratios offers valuable insights into a company's valuation, financial health, and performance. Investors leverage these metrics to make informed decisions and compare companies within the same industry.
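As a quick sanity check, the ratios above can be computed together from one set of figures. The numbers below are invented for illustration and do not describe any real company:

```python
# Key financial ratios computed from one (assumed) set of company figures.

figures = {
    "price_per_share": 50.0,
    "eps": 4.0,                     # earnings per share
    "book_value_per_share": 20.0,
    "total_liabilities": 300.0,     # in millions
    "shareholder_equity": 200.0,    # in millions
    "net_income": 40.0,             # in millions
    "current_assets": 150.0,        # in millions
    "inventory": 30.0,              # in millions
    "current_liabilities": 100.0,   # in millions
    "annual_dividends_per_share": 2.0,
}

ratios = {
    "P/E": figures["price_per_share"] / figures["eps"],
    "P/B": figures["price_per_share"] / figures["book_value_per_share"],
    "D/E": figures["total_liabilities"] / figures["shareholder_equity"],
    "ROE": figures["net_income"] / figures["shareholder_equity"],
    "Current Ratio": figures["current_assets"] / figures["current_liabilities"],
    "Quick Ratio": (figures["current_assets"] - figures["inventory"])
                   / figures["current_liabilities"],
    "Dividend Yield": figures["annual_dividends_per_share"]
                      / figures["price_per_share"],
}

for name, value in ratios.items():
    print(f"{name}: {value:.2f}")
```

With these inputs the company trades at 12.5x earnings and 2.5x book, carries 1.5x equity in liabilities, and yields 4%: whether those are good or bad depends entirely on the industry peers you compare against.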
Options
Black-Scholes Model
The Black-Scholes model is a renowned mathematical model used to price options and other financial derivatives. Developed by Fischer Black and Myron Scholes, the model was first published in 1973. It assumes that the underlying asset's price follows a geometric Brownian motion and uses a no-arbitrage approach to derive the option's price.
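The model's closed-form price for a European option can be written directly from its standard d1/d2 terms. This is a minimal standard-library sketch, with the example inputs (spot, strike, rate, volatility) assumed for illustration:

```python
import math

def norm_cdf(x):
    """Standard normal CDF via the error function."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def black_scholes(S, K, T, r, sigma, kind="call"):
    """Black-Scholes price of a European option.

    S: spot price, K: strike, T: time to expiry in years,
    r: risk-free rate, sigma: annualized volatility.
    """
    d1 = (math.log(S / K) + (r + 0.5 * sigma ** 2) * T) / (sigma * math.sqrt(T))
    d2 = d1 - sigma * math.sqrt(T)
    if kind == "call":
        return S * norm_cdf(d1) - K * math.exp(-r * T) * norm_cdf(d2)
    # Put price from the same d1/d2 (consistent with put-call parity).
    return K * math.exp(-r * T) * norm_cdf(-d2) - S * norm_cdf(-d1)

# At-the-money one-year call: S=100, K=100, r=5%, sigma=20% (assumed inputs).
print(round(black_scholes(100, 100, 1.0, 0.05, 0.20), 2))  # roughly 10.45
```

A useful self-check is put-call parity: for the same inputs, call minus put should equal $S - Ke^{-rT}$ to within floating-point error.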
Greeks
The Greeks are a set of mathematical tools used in the Black-Scholes model to measure the sensitivity of an option's price to changes in various parameters. The most common Greeks include delta, gamma, theta, vega, and rho.
Detailed Explanation of Greeks
The Greeks are essential tools for options traders, providing insights into how different factors impact the price of an option. Here are the most common Greeks and their significance:
- Delta ($\Delta$): Delta measures the sensitivity of an option's price to changes in the price of the underlying asset. It represents the rate of change of the option's price with respect to a $1 change in the underlying asset's price. For call options, delta ranges from 0 to 1, while for put options, delta ranges from -1 to 0. A higher delta indicates greater sensitivity to price changes in the underlying asset.
- Gamma ($\Gamma$): Gamma measures the rate of change of delta with respect to changes in the underlying asset's price. It indicates how much the delta of an option will change for a $1 change in the underlying asset's price. Gamma is highest for at-the-money options and decreases as the option moves further in-the-money or out-of-the-money. High gamma values indicate that delta is more sensitive to price changes in the underlying asset.
- Theta ($\Theta$): Theta measures the sensitivity of an option's price to the passage of time, also known as time decay. It represents the rate at which the option's price decreases as time to expiration approaches. Theta is typically negative for both call and put options, as the value of options erodes over time. Options with shorter time to expiration have higher theta values, indicating faster time decay.
- Vega ($\nu$): Vega measures the sensitivity of an option's price to changes in the volatility of the underlying asset. It represents the amount by which the option's price will change for a 1 percentage point change in the underlying asset's volatility. Higher vega values indicate that the option's price is more sensitive to changes in volatility. Vega is highest for at-the-money options and decreases as the option moves further in-the-money or out-of-the-money.
- Rho ($\rho$): Rho measures the sensitivity of an option's price to changes in interest rates. It represents the amount by which the option's price will change for a 1 percentage point change in the risk-free interest rate. For call options, rho is positive, indicating that an increase in interest rates will increase the option's price. For put options, rho is negative, indicating that an increase in interest rates will decrease the option's price.
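The Greeks above have closed-form expressions under Black-Scholes. The sketch below implements the standard analytic formulas for a European call; the example inputs are assumed, and note the unit conventions in the comments (theta per year, vega and rho per 1.00 change, not per point):

```python
import math

def _pdf(x):
    """Standard normal density."""
    return math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)

def _cdf(x):
    """Standard normal CDF via the error function."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def call_greeks(S, K, T, r, sigma):
    """Analytic Black-Scholes Greeks for a European call."""
    d1 = (math.log(S / K) + (r + 0.5 * sigma ** 2) * T) / (sigma * math.sqrt(T))
    d2 = d1 - sigma * math.sqrt(T)
    return {
        "delta": _cdf(d1),
        "gamma": _pdf(d1) / (S * sigma * math.sqrt(T)),
        # Theta per year; divide by 365 for per-day time decay.
        "theta": -S * _pdf(d1) * sigma / (2 * math.sqrt(T))
                 - r * K * math.exp(-r * T) * _cdf(d2),
        # Vega per 1.00 change in sigma; divide by 100 for per-vol-point.
        "vega": S * _pdf(d1) * math.sqrt(T),
        # Rho per 1.00 change in r; divide by 100 for per-rate-point.
        "rho": K * T * math.exp(-r * T) * _cdf(d2),
    }

# At-the-money one-year call: S=100, K=100, r=5%, sigma=20% (assumed inputs).
g = call_greeks(100, 100, 1.0, 0.05, 0.20)
for name, value in g.items():
    print(f"{name}: {value:.4f}")
```

For these inputs, delta is about 0.64 (a $1 move in the stock moves the option roughly $0.64) and theta is negative, matching the qualitative descriptions above.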
Practical Applications of Greeks
Understanding the Greeks is crucial for options traders, as they help in managing risk and making informed trading decisions. Here are some practical applications:
- Hedging: Traders use delta to hedge their positions by ensuring that the overall delta of their portfolio is neutral, reducing exposure to price movements in the underlying asset.
- Adjusting Positions: Gamma helps traders understand how their delta will change with price movements, allowing them to adjust their positions accordingly.
- Time Decay Management: Theta is important for traders who sell options, as it helps them understand how the value of their options will erode over time.
- Volatility Trading: Vega is crucial for traders who speculate on changes in volatility, as it helps them gauge the impact of volatility changes on their options' prices.
- Interest Rate Impact: Rho is useful for understanding how changes in interest rates will affect the value of options, particularly for long-term options.
By mastering the Greeks, options traders can better navigate the complexities of the options market and enhance their trading strategies.
Option Strategies
Option strategies are various combinations of buying and selling options to achieve specific financial goals, such as hedging risk, generating income, or speculating on price movements. Here are some common option strategies:
1. Covered Call
A covered call involves holding a long position in an underlying asset and selling a call option on that same asset. This strategy generates income from the option premium but limits the upside potential if the asset's price rises significantly.
2. Protective Put
A protective put involves holding a long position in an underlying asset and buying a put option on that same asset. This strategy provides downside protection, as the put option gains value if the asset's price falls.
3. Straddle
A straddle involves buying both a call option and a put option with the same strike price and expiration date. This strategy profits from significant price movements in either direction, making it suitable for volatile markets.
4. Strangle
A strangle involves buying a call option and a put option with different strike prices but the same expiration date. This strategy is similar to a straddle but requires a larger price movement to be profitable.
5. Bull Call Spread
A bull call spread involves buying a call option with a lower strike price and selling a call option with a higher strike price. This strategy profits from a moderate rise in the underlying asset's price while limiting potential losses.
6. Bear Put Spread
A bear put spread involves buying a put option with a higher strike price and selling a put option with a lower strike price. This strategy profits from a moderate decline in the underlying asset's price while limiting potential losses.
7. Iron Condor
An iron condor involves selling an out-of-the-money call option and an out-of-the-money put option while simultaneously buying a further out-of-the-money call option and put option. This strategy profits from low volatility and a narrow price range for the underlying asset.
8. Butterfly Spread
A butterfly spread involves buying a call option (or put option) with a lower strike price, selling two call options (or put options) with a middle strike price, and buying a call option (or put option) with a higher strike price. This strategy profits from low volatility and a stable price for the underlying asset.
9. Calendar Spread
A calendar spread involves buying and selling options with the same strike price but different expiration dates. This strategy profits from changes in volatility and the passage of time.
10. Collar
A collar involves holding a long position in an underlying asset, buying a protective put option, and selling a covered call option. This strategy provides downside protection while limiting upside potential.
Each of these strategies has its own risk and reward profile, making them suitable for different market conditions and investment goals. Understanding and selecting the appropriate option strategy can help investors manage risk and enhance returns.
11. Long Call
A long call involves buying a call option with the expectation that the underlying asset's price will rise above the strike price before the option expires. This strategy offers unlimited profit potential with limited risk, as the maximum loss is the premium paid for the option.
12. Long Put
A long put involves buying a put option with the expectation that the underlying asset's price will fall below the strike price before the option expires. This strategy offers significant profit potential with limited risk, as the maximum loss is the premium paid for the option.
13. Short Call
A short call involves selling a call option without owning the underlying asset. This strategy generates income from the option premium but carries unlimited risk if the asset's price rises significantly.
14. Short Put
A short put involves selling a put option with the expectation that the underlying asset's price will remain above the strike price. This strategy generates income from the option premium but carries significant risk if the asset's price falls below the strike price.
15. Diagonal Spread
A diagonal spread involves buying and selling options with different strike prices and expiration dates. This strategy combines elements of both calendar and vertical spreads, allowing traders to profit from changes in volatility and price movements.
16. Ratio Spread
A ratio spread involves buying a certain number of options and selling a different number of options with the same expiration date but different strike prices. This strategy can be used to profit from moderate price movements while managing risk.
17. Box Spread
A box spread involves combining a bull call spread and a bear put spread with the same strike prices and expiration dates. This strategy is used to lock in a risk-free profit when there is a discrepancy in option pricing.
18. Synthetic Long Stock
A synthetic long stock involves buying a call option and selling a put option with the same strike price and expiration date. This strategy mimics the payoff of holding the underlying asset without actually owning it.
19. Synthetic Short Stock
A synthetic short stock involves selling a call option and buying a put option with the same strike price and expiration date. This strategy mimics the payoff of shorting the underlying asset without actually shorting it.
20. Iron Butterfly
An iron butterfly involves selling an at-the-money call option and an at-the-money put option while simultaneously buying an out-of-the-money call option and an out-of-the-money put option. This strategy profits from low volatility and a stable price for the underlying asset.
By understanding and utilizing these additional option strategies, traders can further diversify their approaches to managing risk and capitalizing on market opportunities. Each strategy has its own unique characteristics and potential benefits, making it essential for traders to carefully consider their objectives and market conditions when selecting an appropriate strategy.
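Expiration payoffs make these strategies concrete. Below is a minimal sketch of the long straddle (strategy 3 above); the strike and premiums are assumed numbers:

```python
# Payoff at expiration for a long straddle: buy a call and a put at the
# same strike. Strike and premiums below are assumed for illustration.

def straddle_payoff(spot_at_expiry, strike, call_premium, put_premium):
    call = max(spot_at_expiry - strike, 0.0)   # call's intrinsic value
    put = max(strike - spot_at_expiry, 0.0)    # put's intrinsic value
    return call + put - call_premium - put_premium

# Strike 100, total premium 8 (assumed). Break-evens sit at 92 and 108.
for s in (80, 92, 100, 108, 120):
    print(s, straddle_payoff(s, 100, 4.0, 4.0))
```

The maximum loss (the combined premium) occurs when the stock pins the strike at expiration, and profit grows the further the price moves in either direction, which is exactly why the straddle is a bet on volatility rather than direction.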
How to Trade Options
Trading options involves several steps, from understanding the market to executing trades. Here is a step-by-step guide on how to trade options:
Step 1: Understand the Basics
Before trading options, it's essential to understand the basics of how options contracts work. This includes knowing the key terms, such as strike price, expiration date, premium, and the difference between call and put options. Familiarize yourself with the different types of options strategies available, such as covered calls, protective puts, and spreads.
Step 2: Choose an Options Broker
To trade options, you need to open an account with an options broker. Look for a broker that offers a user-friendly trading platform, competitive fees, and reliable customer support. Ensure the broker is regulated and has a good reputation in the industry.
Step 3: Develop a Trading Plan
A trading plan is crucial for success in options trading. Your plan should outline your trading goals, risk tolerance, and strategies. Decide on the types of options contracts you want to trade and the timeframes you will focus on. Set clear entry and exit points, as well as stop-loss and take-profit levels.
Step 4: Analyze the Market
Conduct thorough market analysis to identify trading opportunities. Use technical analysis tools, such as charts, indicators, and patterns, to analyze price movements. Additionally, consider fundamental analysis by keeping track of economic news, reports, and events that may impact the options markets.
Step 5: Place Your Trade
Once you have identified a trading opportunity, place your trade through your broker's trading platform. Specify the contract you want to trade, the number of contracts, and the order type (e.g., market order, limit order). Ensure you have sufficient margin in your account to cover the trade.
Step 6: Monitor and Manage Your Trade
After placing your trade, continuously monitor the market and manage your position. Adjust your stop-loss and take-profit levels as needed to protect your profits and limit losses. Be prepared to exit the trade if the market moves against you or if your target is reached.
Step 7: Review and Learn
After closing your trade, review the outcome and analyze your performance. Identify what worked well and what could be improved. Use this information to refine your trading plan and strategies for future trades.
Example Scenario
Consider a trader who wants to trade call options on a tech stock. Here is how they might approach the trade:
- Understand the Basics: The trader learns that a call option gives them the right to buy the stock at a specific price before the expiration date.
- Choose an Options Broker: The trader opens an account with a reputable broker that offers competitive fees and a robust trading platform.
- Develop a Trading Plan: The trader sets a goal to profit from short-term price movements in the tech stock and decides to use technical analysis for entry and exit points.
- Analyze the Market: The trader analyzes the stock's price charts and identifies a bullish trend supported by positive earnings reports.
- Place the Trade: The trader places a market order to buy call options with a strike price close to the current stock price.
- Monitor and Manage: The trader sets a stop-loss order below a recent support level and a take-profit order at a higher resistance level. They monitor the trade and adjust the orders as needed.
- Review and Learn: After closing the trade, the trader reviews the outcome and notes that the bullish trend continued, resulting in a profitable trade. They use this experience to refine their future trading strategies.
Conclusion
Trading options can be a rewarding endeavor, but it requires a solid understanding of the market, a well-developed trading plan, and disciplined execution. By following these steps and continuously learning from your experiences, you can improve your chances of success in the options markets.
Where to Get Good Options Data
Access to reliable and accurate options data is crucial for making informed trading decisions. Here are some sources where you can get good options data:
- Brokerage Platforms: Many brokerage platforms provide comprehensive options data, including real-time quotes, historical data, and analytical tools. Examples include TD Ameritrade, E*TRADE, and Interactive Brokers.
- Financial News Websites: Websites like Yahoo Finance, Google Finance, and Bloomberg offer options data along with news, analysis, and market insights.
- Market Data Providers: Companies like Cboe Global Markets, Nasdaq, and NYSE provide extensive options data, including real-time and historical data, market statistics, and analytics.
- Data Aggregators: Services like Options Data Warehouse and Quandl aggregate options data from multiple sources, providing a centralized platform for accessing comprehensive data sets.
- Specialized Tools: Tools like OptionVue, LiveVol, and ThinkOrSwim offer advanced options analysis and data visualization features, catering to both retail and professional traders.
Brokers with Automated Trading
Automated trading can help you execute trades more efficiently and take advantage of market opportunities in real-time. Here are some brokers that offer automated trading capabilities:
- Interactive Brokers: Interactive Brokers provides a robust API that allows traders to automate their trading strategies using various programming languages, including Python, Java, and C++.
- TD Ameritrade: TD Ameritrade's thinkorswim platform offers automated trading through its thinkScript language, enabling traders to create custom scripts and strategies.
- E*TRADE: E*TRADE offers automated trading through its API, allowing traders to develop and implement automated trading strategies using their preferred programming languages.
- TradeStation: TradeStation provides a powerful platform for automated trading, with EasyLanguage for strategy development and integration with various third-party tools and APIs.
- Alpaca: Alpaca is a commission-free broker that offers a user-friendly API for automated trading, making it accessible for both beginner and experienced traders.
- QuantConnect: QuantConnect is a cloud-based algorithmic trading platform that integrates with multiple brokers, including Interactive Brokers and Tradier, allowing traders to develop and deploy automated trading strategies.
By leveraging these sources for options data and brokers with automated trading capabilities, you can enhance your trading strategies and improve your overall trading performance.
Futures
What are Futures?
Futures are financial contracts obligating the buyer to purchase an asset or the seller to sell an asset at a predetermined future date and price. These contracts are standardized and traded on futures exchanges. Futures can be used for hedging or speculative purposes.
Key Features of Futures
- Standardization: Futures contracts are standardized in terms of quantity, quality, and delivery time, making them easily tradable on exchanges.
- Leverage: Futures allow traders to control large positions with a relatively small amount of capital, providing the potential for significant gains or losses.
- Margin Requirements: Traders are required to deposit a margin, which is a fraction of the contract's value, to enter into a futures position. This margin acts as a security deposit to cover potential losses.
- Settlement: Futures contracts can be settled either by physical delivery of the underlying asset or by cash settlement, depending on the terms of the contract.
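The leverage and margin features above reduce to simple arithmetic. The contract size and margin figure below are assumed for illustration and do not reflect any exchange's actual specifications:

```python
# Back-of-the-envelope futures leverage arithmetic (all figures assumed).

contract_size = 1_000          # barrels per (assumed) crude oil contract
price_per_barrel = 75.0
initial_margin = 6_000.0       # assumed exchange margin requirement

notional = contract_size * price_per_barrel   # value the position controls
leverage = notional / initial_margin

# A $1 move in the futures price changes the position's value by the
# contract size -- here a large fraction of the margin actually posted.
pnl_per_dollar_move = contract_size * 1.0

print(f"Notional: ${notional:,.0f}, leverage ~{leverage:.1f}x, "
      f"P&L per $1 move: ${pnl_per_dollar_move:,.0f}")
```

With these assumed numbers a $6 adverse move wipes out the entire margin deposit, which is why exchanges issue margin calls well before that point.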
Types of Futures Contracts
- Commodity Futures: These contracts involve physical commodities such as oil, gold, wheat, and corn. They are commonly used by producers and consumers to hedge against price fluctuations.
- Financial Futures: These contracts involve financial instruments such as currencies, interest rates, and stock indices. They are often used by investors and institutions to manage financial risk.
- Index Futures: These contracts are based on stock market indices like the S&P 500 or the Dow Jones Industrial Average. They allow traders to speculate on the overall direction of the market.
- Currency Futures: These contracts involve the exchange of one currency for another at a future date. They are used by businesses and investors to hedge against currency risk.
Example Scenario
Consider a wheat farmer who wants to lock in a price for their crop to protect against the risk of falling prices. The farmer can sell wheat futures contracts, agreeing to deliver a specified quantity of wheat at a predetermined price on a future date. If the market price of wheat falls, the farmer is protected because they have locked in a higher price through the futures contract.
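The farmer's hedge above can be verified with a few lines of arithmetic. The bushel count and prices are assumed; the point is that the short futures position's gain or loss exactly offsets the change in crop revenue:

```python
# The wheat-farmer hedge in numbers (all figures assumed): sell futures
# at a locked-in price, then compare outcomes at delivery.

def hedged_revenue(locked_price, market_price, bushels):
    # Revenue from selling the crop at the prevailing market price...
    crop = market_price * bushels
    # ...plus the gain (or loss) on the short futures position.
    futures_pnl = (locked_price - market_price) * bushels
    return crop + futures_pnl

# 5,000 bushels locked in at $6.00: whether the market ends at $5.00
# or $7.00, the hedged revenue is the same.
print(hedged_revenue(6.00, 5.00, 5000))  # 30000.0
print(hedged_revenue(6.00, 7.00, 5000))  # 30000.0
```

The symmetry is the point of a hedge: the farmer gives up the upside of rising prices in exchange for certainty about revenue.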
Advantages of Futures
- Risk Management: Futures allow businesses and investors to hedge against price fluctuations, reducing uncertainty and managing risk.
- Liquidity: Futures markets are highly liquid, allowing traders to enter and exit positions easily.
- Price Discovery: Futures markets provide valuable information about future price expectations, helping businesses and investors make informed decisions.
- Diversification: Futures offer opportunities to diversify investment portfolios by gaining exposure to different asset classes.
Conclusion
Futures are powerful financial instruments that provide opportunities for hedging and speculation. By understanding the key features, types, and advantages of futures, traders and investors can effectively manage risk and capitalize on market opportunities. Whether used for hedging against price fluctuations or speculating on market movements, futures play a crucial role in the global financial markets.
Difference Between Futures and Options
Futures and options are both financial derivatives that allow traders to speculate on the price movements of underlying assets. However, there are key differences between the two:
Futures
- Obligation: Futures contracts obligate the buyer to purchase and the seller to sell the underlying asset at a predetermined price and date.
- Standardization: Contracts are standardized in quantity, quality, and delivery time, making them easily tradable on exchanges.
- Leverage: Traders control large positions with relatively little capital, amplifying both gains and losses.
- Margin Requirements: A margin deposit, a fraction of the contract's value, is required as security against potential losses.
- Settlement: Contracts settle either by physical delivery of the underlying asset or by cash, depending on their terms.
Options
- Right, Not Obligation: Options contracts give the buyer the right, but not the obligation, to buy (call option) or sell (put option) the underlying asset at a predetermined price and date.
- Premium: The buyer of an options contract pays a premium to the seller for the right to exercise the option. This premium is the maximum loss the buyer can incur.
- Leverage: Options also provide leverage, allowing traders to control large positions with a relatively small amount of capital. However, the potential loss for the buyer is limited to the premium paid.
- Types of Options: There are two main types of options: call options and put options. Call options give the buyer the right to buy the underlying asset, while put options give the buyer the right to sell the underlying asset.
- Expiration: Options contracts have an expiration date, after which the option becomes worthless if not exercised.
Key Differences
- Obligation vs. Right: Futures contracts create an obligation for both parties, while options contracts provide the buyer with a right without obligation.
- Risk and Reward: In futures, both parties are exposed to large, effectively unlimited losses as the price moves against them. In options, the buyer's risk is limited to the premium paid, while the seller faces unlimited risk on calls and substantial risk on puts.
- Cost: Futures require margin deposits, while options require the payment of a premium.
- Flexibility: Options offer more flexibility due to the right to exercise, while futures are more rigid with mandatory settlement.
Example Scenario
Consider an investor who wants to speculate on the price of gold. They can choose between futures and options:
- Futures: The investor buys a gold futures contract, obligating them to purchase gold at a specified price on a future date. If the price of gold rises, the investor profits. If the price falls, the investor incurs a loss.
- Options: The investor buys a call option on gold, giving them the right to buy gold at a specified price on or before the expiration date. If the price of gold rises, the investor can exercise the option and profit. If the price falls, the investor's loss is limited to the premium paid for the option.
Conclusion
Both futures and options are valuable tools for traders and investors to manage risk and speculate on price movements. Understanding the differences between these derivatives is crucial for making informed trading decisions. Futures provide an obligation to buy or sell, while options offer the right without obligation, each with its own risk and reward profile.
How to Trade Futures
Trading futures involves several steps, from understanding the market to executing trades. Here is a step-by-step guide on how to trade futures:
Step 1: Understand the Basics
Before trading futures, it's essential to understand the basics of how futures contracts work. This includes knowing the key terms, such as contract size, expiration date, and margin requirements. Familiarize yourself with the different types of futures contracts available, such as commodities, financials, and indices.
Step 2: Choose a Futures Broker
To trade futures, you need to open an account with a futures broker. Look for a broker that offers a user-friendly trading platform, competitive fees, and reliable customer support. Ensure the broker is regulated and has a good reputation in the industry.
Step 3: Develop a Trading Plan
A trading plan is crucial for success in futures trading. Your plan should outline your trading goals, risk tolerance, and strategies. Decide on the types of futures contracts you want to trade and the timeframes you will focus on. Set clear entry and exit points, as well as stop-loss and take-profit levels.
Step 4: Analyze the Market
Conduct thorough market analysis to identify trading opportunities. Use technical analysis tools, such as charts, indicators, and patterns, to analyze price movements. Additionally, consider fundamental analysis by keeping track of economic news, reports, and events that may impact the futures markets.
Step 5: Place Your Trade
Once you have identified a trading opportunity, place your trade through your broker's trading platform. Specify the contract you want to trade, the number of contracts, and the order type (e.g., market order, limit order). Ensure you have sufficient margin in your account to cover the trade.
Step 6: Monitor and Manage Your Trade
After placing your trade, continuously monitor the market and manage your position. Adjust your stop-loss and take-profit levels as needed to protect your profits and limit losses. Be prepared to exit the trade if the market moves against you or if your target is reached.
Step 7: Review and Learn
After closing your trade, review the outcome and analyze your performance. Identify what worked well and what could be improved. Use this information to refine your trading plan and strategies for future trades.
Example Scenario
Consider a trader who wants to trade crude oil futures. Here is how they might approach the trade:
- Understand the Basics: The trader learns that a crude oil futures contract represents 1,000 barrels of oil and has specific expiration dates.
- Choose a Futures Broker: The trader opens an account with a reputable broker that offers competitive fees and a robust trading platform.
- Develop a Trading Plan: The trader sets a goal to profit from short-term price movements in crude oil and decides to use technical analysis for entry and exit points.
- Analyze the Market: The trader analyzes crude oil price charts and identifies a bullish trend supported by positive economic news.
- Place the Trade: The trader places a market order to buy one crude oil futures contract at the current price.
- Monitor and Manage: The trader sets a stop-loss order below a recent support level and a take-profit order at a higher resistance level. They monitor the trade and adjust the orders as needed.
- Review and Learn: After closing the trade, the trader reviews the outcome and notes that the bullish trend continued, resulting in a profitable trade. They use this experience to refine their future trading strategies.
Conclusion
Trading futures can be a rewarding endeavor, but it requires a solid understanding of the market, a well-developed trading plan, and disciplined execution. By following these steps and continuously learning from your experiences, you can improve your chances of success in the futures markets.
Crypto
Proof of Work
Proof of Work (PoW) is a consensus mechanism used in blockchain networks to validate transactions and secure the network. It requires participants, known as miners, to solve complex mathematical puzzles to add new blocks to the blockchain. The first miner to solve the puzzle gets the right to add the block and is rewarded with cryptocurrency.
How Proof of Work Works
- Transaction Collection: Miners collect and verify transactions from the network, grouping them into a block.
- Puzzle Solving: Miners compete to solve a cryptographic puzzle, which involves finding a nonce (a random number) that, when hashed with the block's data, produces a hash that meets the network's difficulty target.
- Block Validation: The first miner to solve the puzzle broadcasts the solution to the network. Other miners validate the solution and the block.
- Block Addition: Once validated, the block is added to the blockchain, and the miner receives a reward, typically in the form of newly minted cryptocurrency and transaction fees.
- Difficulty Adjustment: The network periodically adjusts the difficulty of the puzzle to ensure a consistent block generation time, usually around 10 minutes for Bitcoin.
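The puzzle-solving step above is easy to see in code. Here is a toy Python sketch (the block data is made up, and real networks express difficulty as a numeric target rather than a count of leading zeros):

```python
import hashlib

def mine(block_data: str, difficulty: int) -> tuple[int, str]:
    """Search for a nonce whose SHA-256 hash has `difficulty` leading zero hex digits."""
    target = "0" * difficulty
    nonce = 0
    while True:
        digest = hashlib.sha256(f"{block_data}{nonce}".encode()).hexdigest()
        if digest.startswith(target):
            return nonce, digest
        nonce += 1  # try the next nonce

nonce, digest = mine("prev_hash|tx1,tx2", difficulty=4)
print(nonce, digest)  # the winning nonce and its hash, starting with "0000"
```

Raising `difficulty` by one makes the search roughly 16 times harder, which is why real mining consumes so much computation.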
Key Concepts
- Hash Function: A cryptographic function that converts input data into a fixed-size string of characters, which appears random. Bitcoin uses the SHA-256 hash function.
- Nonce: A number that miners repeatedly change (typically by incrementing) until the block's hash meets the difficulty target.
- Difficulty Target: A value that determines how hard it is to find a valid hash. The lower the target, the more difficult the puzzle.
- Block Reward: The incentive miners receive for adding a new block to the blockchain. This reward decreases over time in events known as "halvings."
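The halving schedule is plain arithmetic: Bitcoin's block subsidy began at 50 BTC and halves every 210,000 blocks.

```python
def block_reward(height: int) -> float:
    """Bitcoin block subsidy in BTC at a given block height (halves every 210,000 blocks)."""
    return 50 / 2 ** (height // 210_000)

print(block_reward(0))        # 50.0  (genesis era)
print(block_reward(210_000))  # 25.0  (first halving)
print(block_reward(840_000))  # 3.125 (fourth halving, 2024)
```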
Advantages of Proof of Work
- Security: PoW provides strong security by making it computationally expensive to alter the blockchain. An attacker would need more computational power than the rest of the network combined to succeed.
- Decentralization: PoW promotes decentralization by allowing anyone with the necessary hardware to participate in mining, reducing the risk of central control.
- Proven Track Record: PoW has been successfully used by Bitcoin and other cryptocurrencies for over a decade, demonstrating its effectiveness in securing blockchain networks.
Disadvantages of Proof of Work
- Energy Consumption: PoW requires significant computational power, leading to high energy consumption and environmental concerns.
- Centralization Risk: Over time, mining can become concentrated in regions with cheap electricity or among entities with access to specialized hardware, potentially reducing decentralization.
- Scalability: PoW can limit the scalability of blockchain networks due to the time and resources required to solve puzzles and add new blocks.
Conclusion
Proof of Work is a foundational consensus mechanism in blockchain technology, providing security and decentralization through computational effort. While it has proven effective, its energy consumption and scalability challenges have led to the exploration of alternative mechanisms like Proof of Stake (PoS). Nonetheless, PoW remains a critical component of many blockchain networks, ensuring the integrity and trustworthiness of decentralized systems.
Proof of Stake
Proof of Stake (PoS) is an alternative consensus mechanism to Proof of Work (PoW) used in blockchain networks to validate transactions and secure the network. Instead of relying on computational power to solve complex puzzles, PoS selects validators based on the number of coins they hold and are willing to "stake" as collateral.
How Proof of Stake Works
- Validator Selection: Validators are chosen to create new blocks and validate transactions based on the number of coins they hold and lock up as collateral. The more coins a validator stakes, the higher their chances of being selected.
- Block Creation: The selected validator creates a new block and adds it to the blockchain. This process is known as "forging" or "minting" rather than "mining."
- Transaction Validation: Other validators in the network verify the new block. If the block is valid, it is added to the blockchain, and the validator receives a reward.
- Slashing: If a validator is found to act maliciously or validate fraudulent transactions, a portion of their staked coins can be forfeited as a penalty. This mechanism is known as "slashing" and helps maintain network security and integrity.
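The stake-weighted selection in step 1 can be sketched with Python's standard library. The validator names and stakes below are made up, and real protocols layer on verifiable randomness and other safeguards:

```python
import random

# Hypothetical validators and their staked amounts
stakes = {"alice": 600, "bob": 300, "carol": 100}

def pick_validator(stakes: dict[str, int], rng: random.Random) -> str:
    """Select a validator with probability proportional to its stake."""
    names = list(stakes)
    return rng.choices(names, weights=[stakes[n] for n in names], k=1)[0]

rng = random.Random(42)  # fixed seed so the sketch is reproducible
picks = [pick_validator(stakes, rng) for _ in range(10_000)]
print(picks.count("alice") / len(picks))  # close to 0.6, matching alice's 60% stake share
```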
Key Concepts
- Staking: The process of locking up a certain amount of cryptocurrency to participate in the validation process. Validators are incentivized to act honestly to avoid losing their staked coins.
- Validator: A participant in the network who is responsible for creating new blocks and validating transactions. Validators are chosen based on the amount of cryptocurrency they stake.
- Slashing: A penalty mechanism that confiscates a portion of a validator's staked coins if they are found to act maliciously or validate fraudulent transactions.
- Delegated Proof of Stake (DPoS): A variation of PoS where stakeholders vote for a small number of delegates to validate transactions and create new blocks on their behalf. This system aims to improve efficiency and scalability.
Advantages of Proof of Stake
- Energy Efficiency: PoS is significantly more energy-efficient than PoW, as it does not require extensive computational power to validate transactions and create new blocks.
- Security: PoS provides strong security by aligning the interests of validators with the network. Validators are incentivized to act honestly to avoid losing their staked coins.
- Decentralization: PoS promotes decentralization by allowing a broader range of participants to become validators, as it does not require specialized hardware or significant energy consumption.
- Scalability: PoS can improve the scalability of blockchain networks by reducing the time and resources required to validate transactions and create new blocks.
Disadvantages of Proof of Stake
- Wealth Concentration: PoS can lead to wealth concentration, as validators with more coins have a higher chance of being selected to create new blocks and earn rewards.
- Initial Distribution: The initial distribution of coins can impact the fairness and decentralization of the network, as early adopters or large holders may have more influence.
- Complexity: PoS mechanisms can be more complex to implement and understand compared to PoW, requiring careful design to ensure security and fairness.
Conclusion
Proof of Stake is a promising alternative to Proof of Work, offering significant improvements in energy efficiency, security, and scalability. By selecting validators based on the number of coins they stake, PoS aligns the interests of participants with the network's integrity. While it has its challenges, such as potential wealth concentration and complexity, PoS continues to gain traction as a viable consensus mechanism for blockchain networks, driving innovation and sustainability in the cryptocurrency space.
Solana
Important Concepts and Token Economics of Solana
Solana is a high-performance blockchain platform designed for decentralized applications and cryptocurrencies. It aims to provide scalability without compromising decentralization and security. Here are some of its key concepts and token economics:
Important Concepts
- Proof of History (PoH): Proof of History is a unique consensus mechanism used by Solana to timestamp transactions before they are included in the blockchain. PoH creates a historical record that proves an event occurred at a specific moment in time, allowing the network to order transactions and improve efficiency.
- Tower BFT: Tower Byzantine Fault Tolerance (BFT) is Solana's consensus algorithm that leverages PoH as a cryptographic clock to achieve consensus. Tower BFT reduces communication overhead and latency, enabling faster transaction finality.
- Turbine: Turbine is Solana's block propagation protocol. It breaks data into smaller packets and transmits them across the network in a way that reduces bandwidth requirements and increases the speed of data transmission.
- Gulf Stream: Gulf Stream is Solana's mempool-less transaction forwarding protocol. It pushes transaction caching and forwarding to the edge of the network, allowing validators to execute transactions ahead of time, reducing confirmation times and improving network efficiency.
- Sealevel: Sealevel is Solana's parallel smart contract runtime. It allows multiple smart contracts to run in parallel, leveraging the multi-core processors in modern hardware to achieve high throughput.
- Pipelining: Pipelining optimizes validation by handling different stages of transaction processing on different hardware units, improving overall throughput.
- Cloudbreak: Cloudbreak is Solana's horizontally scalable accounts database. It allows the network to handle a large number of accounts and transactions efficiently by distributing the data across multiple storage devices.
- Archivers: Archivers are nodes in the Solana network responsible for storing data. They offload the storage burden from validators, keeping the blockchain lightweight and efficient.
Token Economics
- SOL Token: SOL is the native cryptocurrency of the Solana network. It is used to pay transaction fees, participate in the network's consensus mechanism, and interact with smart contracts.
- Staking: SOL holders can stake their tokens to become validators or delegate them to other validators. Staking helps secure the network, and participants earn rewards in the form of additional SOL.
- Inflation: Solana has an inflationary supply model, where new SOL tokens are minted and distributed as staking rewards. The initial inflation rate is set at 8% per year and is designed to decrease over time, eventually stabilizing at around 1.5% per year.
- Transaction Fees: Transaction fees on the Solana network are paid in SOL. These fees are low compared to other blockchain networks, making Solana attractive for high-frequency and micro-transactions.
- Burn Mechanism: A portion of the transaction fees collected on the network is burned, reducing the total supply of SOL over time. This deflationary pressure helps counteract the inflationary supply model and can potentially increase the value of SOL.
- Ecosystem Incentives: Solana runs various incentive programs to encourage the development and growth of its ecosystem, including grants, hackathons, and partnerships aimed at attracting developers, projects, and users to the platform.
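The inflation schedule above can be sketched numerically. The parameters below follow Solana's published design (8% initial rate, roughly 15% disinflation per year, 1.5% terminal rate), but check the current documentation for exact values:

```python
def inflation_rate(year: int, initial: float = 8.0,
                   disinflation: float = 0.15, terminal: float = 1.5) -> float:
    """Approximate annual SOL inflation (%) in a given year after launch:
    starts at `initial` and decays by `disinflation` per year, floored at `terminal`."""
    return max(terminal, initial * (1 - disinflation) ** year)

for year in (0, 5, 10, 30):
    print(year, round(inflation_rate(year), 2))  # decays from 8.0 toward the 1.5 floor
```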
Solana's innovative technology and well-designed token economics make it a promising platform for scalable and efficient decentralized applications. Its focus on high throughput, low latency, and low transaction costs positions it as a strong contender in the blockchain space.
Databases & Data Engineering
Database systems and data engineering concepts for storing, querying, and managing data at scale.
Topics Covered
Relational Databases
- SQL - SQL fundamentals, queries, joins, indexes, transactions
- PostgreSQL - Advanced PostgreSQL features, JSON support, performance tuning
- SQLite - Lightweight embedded database for applications
- DuckDB - Analytical database for data analysis and OLAP queries
NoSQL Databases
- NoSQL - NoSQL databases overview, types, and use cases
- MongoDB - Document-oriented NoSQL database with rich query language
- Redis - In-memory data store for caching, pub/sub, and real-time applications
Message Queues & Event Streaming
- Apache Kafka - Distributed event streaming platform for high-throughput data pipelines
Database Concepts
- Data Modeling: Schema design, normalization, relationships
- Caching: In-memory stores, cache invalidation strategies
- Data Pipelines: ETL, streaming, batch processing
- Database Optimization: Query optimization, indexing strategies
Navigation
Use the menu to explore each topic in depth.
SQL (Structured Query Language)
Overview
SQL is the standard language for querying and managing relational databases. Used by PostgreSQL, MySQL, SQL Server, Oracle, and others.
Basic Queries
SELECT
SELECT column1, column2 FROM table WHERE condition;
SELECT * FROM users WHERE age > 18;
SELECT DISTINCT city FROM customers;
INSERT
INSERT INTO users (name, email) VALUES ('John', 'john@example.com');
INSERT INTO users VALUES (1, 'John', 'john@example.com');
UPDATE
UPDATE users SET age = 30 WHERE id = 1;
UPDATE products SET price = price * 1.1;
DELETE
DELETE FROM users WHERE id = 1;
DELETE FROM logs WHERE created_at < '2023-01-01';
Joins
-- INNER JOIN: Only matching rows
SELECT u.name, o.order_id
FROM users u INNER JOIN orders o ON u.id = o.user_id;
-- LEFT JOIN: All from left table
SELECT u.name, o.order_id
FROM users u LEFT JOIN orders o ON u.id = o.user_id;
-- RIGHT JOIN: All from right table
SELECT u.name, o.order_id
FROM users u RIGHT JOIN orders o ON u.id = o.user_id;
-- FULL OUTER JOIN: All rows
SELECT u.name, o.order_id
FROM users u FULL OUTER JOIN orders o ON u.id = o.user_id;
Aggregation
SELECT COUNT(*) FROM users;
SELECT AVG(price) FROM products;
SELECT SUM(amount) FROM transactions WHERE status = 'completed';
SELECT MAX(salary) FROM employees;
-- GROUP BY
SELECT department, COUNT(*) FROM employees GROUP BY department;
SELECT category, AVG(price) FROM products GROUP BY category;
-- HAVING (filter groups)
SELECT department, AVG(salary)
FROM employees
GROUP BY department
HAVING AVG(salary) > 50000;
Indexes
-- Create index for faster queries
CREATE INDEX idx_email ON users(email);
CREATE INDEX idx_user_date ON orders(user_id, created_at);
-- Drop index
DROP INDEX idx_email;
Transactions
BEGIN TRANSACTION;
UPDATE accounts SET balance = balance - 100 WHERE id = 1;
UPDATE accounts SET balance = balance + 100 WHERE id = 2;
COMMIT; -- Save changes
-- ROLLBACK; -- Undo changes
Window Functions
-- Rank rows
SELECT name, salary,
RANK() OVER (ORDER BY salary DESC) as rank
FROM employees;
-- Running total
SELECT date, amount,
SUM(amount) OVER (ORDER BY date) as running_total
FROM transactions;
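Window functions are easy to experiment with locally: SQLite (3.25+) also supports them, so the running-total query above can be tried with nothing but Python's standard library:

```python
import sqlite3

con = sqlite3.connect(":memory:")  # throwaway in-memory database
con.execute("CREATE TABLE transactions (date TEXT, amount REAL)")
con.executemany("INSERT INTO transactions VALUES (?, ?)",
                [("2024-01-01", 100.0), ("2024-01-02", 50.0), ("2024-01-03", 25.0)])

# SUM(...) OVER (ORDER BY date) accumulates row by row
rows = con.execute("""
    SELECT date, amount,
           SUM(amount) OVER (ORDER BY date) AS running_total
    FROM transactions
""").fetchall()
for row in rows:
    print(row)  # running_total grows: 100.0, 150.0, 175.0
```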
Common Patterns
Duplicate Finding
SELECT email, COUNT(*) FROM users GROUP BY email HAVING COUNT(*) > 1;
Top Row per Group
-- PostgreSQL-specific DISTINCT ON: highest salary per department
SELECT DISTINCT ON (department) name, salary, department
FROM employees ORDER BY department, salary DESC;
Data Validation
SELECT * FROM users WHERE email NOT LIKE '%@%.%';
Performance Tips
- Use indexes on frequently queried columns
- EXPLAIN query plans: EXPLAIN SELECT ...
- Avoid SELECT * - specify only the columns you need
- Use LIMIT for large result sets
- Batch operations instead of individual queries
ACID Properties
- Atomicity: All or nothing
- Consistency: Valid state to valid state
- Isolation: Concurrent transactions independent
- Durability: Committed data survives failures
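Atomicity is easy to demonstrate with SQLite via Python's built-in sqlite3 module: a transfer that would violate a constraint is rolled back in full, leaving both balances untouched. The table and amounts below are made up for illustration:

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE accounts (id INTEGER PRIMARY KEY, balance REAL CHECK (balance >= 0))")
con.executemany("INSERT INTO accounts VALUES (?, ?)", [(1, 100.0), (2, 50.0)])
con.commit()

def transfer(con, src, dst, amount):
    """Move `amount` between accounts atomically: both UPDATEs commit, or neither does."""
    try:
        with con:  # the connection context manager commits on success, rolls back on error
            con.execute("UPDATE accounts SET balance = balance + ? WHERE id = ?", (amount, dst))
            con.execute("UPDATE accounts SET balance = balance - ? WHERE id = ?", (amount, src))
    except sqlite3.IntegrityError:
        pass  # overdraft: the CHECK constraint fired, and the credit above was also undone

transfer(con, 1, 2, 500)  # would overdraw account 1 -> whole transfer rolled back
print(con.execute("SELECT balance FROM accounts ORDER BY id").fetchall())  # balances unchanged
```

Note the credit runs first: when the debit then fails, the rollback undoes the already-applied credit, which is exactly the "all or nothing" guarantee.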
ELI10
SQL is like a filing system for data:
- SELECT: "Show me these files"
- INSERT: "Add new file"
- UPDATE: "Modify existing file"
- DELETE: "Remove file"
Joins = combining data from multiple filing cabinets!
PostgreSQL
PostgreSQL is a powerful, open-source object-relational database system with over 35 years of active development. It's known for its reliability, feature robustness, and performance.
Installation
# Ubuntu/Debian
sudo apt update
sudo apt install postgresql postgresql-contrib
# macOS
brew install postgresql@15
brew services start postgresql@15
# CentOS/RHEL
sudo yum install postgresql-server postgresql-contrib
sudo postgresql-setup initdb
sudo systemctl start postgresql
# Check version
psql --version
Basic Usage
# Connect as postgres user
sudo -u postgres psql
# Connect to specific database
psql -U username -d database_name
# Connect to remote database
psql -h hostname -U username -d database_name
# Execute SQL file
psql -U username -d database_name -f script.sql
# Execute command from shell
psql -U username -d database_name -c "SELECT * FROM users;"
Database Operations
-- Create database
CREATE DATABASE mydb;
-- List databases
\l
\list
-- Connect to database
\c mydb
\connect mydb
-- Drop database
DROP DATABASE mydb;
-- Create database with options
CREATE DATABASE mydb
WITH OWNER = myuser
ENCODING = 'UTF8'
LC_COLLATE = 'en_US.UTF-8'
LC_CTYPE = 'en_US.UTF-8'
TEMPLATE = template0;
Table Operations
-- Create table
CREATE TABLE users (
id SERIAL PRIMARY KEY,
username VARCHAR(50) UNIQUE NOT NULL,
email VARCHAR(100) UNIQUE NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
-- List tables
\dt
\dt+ -- with sizes
-- Describe table
\d users
\d+ users -- detailed
-- Drop table
DROP TABLE users;
DROP TABLE IF EXISTS users;
-- Alter table
ALTER TABLE users ADD COLUMN age INTEGER;
ALTER TABLE users DROP COLUMN age;
ALTER TABLE users RENAME COLUMN username TO user_name;
ALTER TABLE users ALTER COLUMN email SET NOT NULL;
CRUD Operations
-- Insert
INSERT INTO users (username, email)
VALUES ('john', 'john@example.com');
-- Insert multiple
INSERT INTO users (username, email) VALUES
('alice', 'alice@example.com'),
('bob', 'bob@example.com');
-- Insert with RETURNING
INSERT INTO users (username, email)
VALUES ('jane', 'jane@example.com')
RETURNING id, username;
-- Select
SELECT * FROM users;
SELECT username, email FROM users WHERE id = 1;
SELECT * FROM users WHERE username LIKE 'jo%';
SELECT * FROM users ORDER BY created_at DESC LIMIT 10;
-- Update
UPDATE users SET email = 'newemail@example.com' WHERE id = 1;
UPDATE users SET email = 'newemail@example.com' WHERE id = 1 RETURNING *;
-- Delete
DELETE FROM users WHERE id = 1;
DELETE FROM users WHERE created_at < '2023-01-01';
Indexes
-- Create index
CREATE INDEX idx_users_username ON users(username);
CREATE INDEX idx_users_email ON users(email);
-- Unique index
CREATE UNIQUE INDEX idx_users_username_unique ON users(username);
-- Composite index
CREATE INDEX idx_users_name_email ON users(username, email);
-- Partial index
CREATE INDEX idx_active_users ON users(username) WHERE active = true;
-- Full-text search index
CREATE INDEX idx_users_fulltext ON users USING GIN(to_tsvector('english', username || ' ' || email));
-- List indexes
\di
SELECT * FROM pg_indexes WHERE tablename = 'users';
-- Drop index
DROP INDEX idx_users_username;
Constraints
-- Primary key
CREATE TABLE products (
id SERIAL PRIMARY KEY,
name VARCHAR(100)
);
-- Foreign key
CREATE TABLE orders (
id SERIAL PRIMARY KEY,
user_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
product_id INTEGER REFERENCES products(id)
);
-- Unique constraint
ALTER TABLE users ADD CONSTRAINT users_email_unique UNIQUE (email);
-- Check constraint
ALTER TABLE products ADD CONSTRAINT products_price_positive
CHECK (price > 0);
-- Not null
ALTER TABLE users ALTER COLUMN email SET NOT NULL;
-- Default
ALTER TABLE users ALTER COLUMN active SET DEFAULT true;
Joins
-- Inner join
SELECT u.username, o.id AS order_id
FROM users u
INNER JOIN orders o ON u.id = o.user_id;
-- Left join
SELECT u.username, o.id AS order_id
FROM users u
LEFT JOIN orders o ON u.id = o.user_id;
-- Right join
SELECT u.username, o.id AS order_id
FROM users u
RIGHT JOIN orders o ON u.id = o.user_id;
-- Full outer join
SELECT u.username, o.id AS order_id
FROM users u
FULL OUTER JOIN orders o ON u.id = o.user_id;
-- Self join
SELECT e1.name AS employee, e2.name AS manager
FROM employees e1
LEFT JOIN employees e2 ON e1.manager_id = e2.id;
Aggregations
-- Count
SELECT COUNT(*) FROM users;
SELECT COUNT(DISTINCT email) FROM users;
-- Sum, Avg, Min, Max
SELECT
COUNT(*) AS total_orders,
SUM(amount) AS total_amount,
AVG(amount) AS avg_amount,
MIN(amount) AS min_amount,
MAX(amount) AS max_amount
FROM orders;
-- Group by
SELECT user_id, COUNT(*) AS order_count
FROM orders
GROUP BY user_id;
-- Having
SELECT user_id, COUNT(*) AS order_count
FROM orders
GROUP BY user_id
HAVING COUNT(*) > 5;
-- Window functions
SELECT
username,
created_at,
ROW_NUMBER() OVER (ORDER BY created_at) AS row_num,
RANK() OVER (ORDER BY created_at) AS rank,
LAG(created_at) OVER (ORDER BY created_at) AS prev_created
FROM users;
Transactions
-- Begin transaction
BEGIN;
INSERT INTO users (username, email) VALUES ('test', 'test@example.com');
UPDATE accounts SET balance = balance - 100 WHERE id = 1;
UPDATE accounts SET balance = balance + 100 WHERE id = 2;
-- Commit
COMMIT;
-- Rollback
ROLLBACK;
-- Savepoint
BEGIN;
INSERT INTO users (username, email) VALUES ('test', 'test@example.com');
SAVEPOINT my_savepoint;
UPDATE users SET email = 'new@example.com' WHERE username = 'test';
ROLLBACK TO my_savepoint;
COMMIT;
Views
-- Create view
CREATE VIEW active_users AS
SELECT id, username, email
FROM users
WHERE active = true;
-- Use view
SELECT * FROM active_users;
-- Materialized view
CREATE MATERIALIZED VIEW user_stats AS
SELECT
user_id,
COUNT(*) AS order_count,
SUM(amount) AS total_spent
FROM orders
GROUP BY user_id;
-- Refresh materialized view
REFRESH MATERIALIZED VIEW user_stats;
-- Drop view
DROP VIEW active_users;
DROP MATERIALIZED VIEW user_stats;
Functions and Procedures
-- Create function
CREATE OR REPLACE FUNCTION get_user_count()
RETURNS INTEGER AS $$
BEGIN
RETURN (SELECT COUNT(*) FROM users);
END;
$$ LANGUAGE plpgsql;
-- Call function
SELECT get_user_count();
-- Function with parameters
CREATE OR REPLACE FUNCTION get_user_by_id(user_id INTEGER)
RETURNS TABLE(username VARCHAR, email VARCHAR) AS $$
BEGIN
RETURN QUERY
SELECT u.username, u.email
FROM users u
WHERE u.id = user_id;
END;
$$ LANGUAGE plpgsql;
-- Call
SELECT * FROM get_user_by_id(1);
-- Procedure (PostgreSQL 11+)
CREATE OR REPLACE PROCEDURE add_user(
p_username VARCHAR,
p_email VARCHAR
)
LANGUAGE plpgsql AS $$
BEGIN
INSERT INTO users (username, email)
VALUES (p_username, p_email);
END;
$$;
-- Call procedure
CALL add_user('newuser', 'new@example.com');
Triggers
-- Create trigger function
CREATE OR REPLACE FUNCTION update_modified_column()
RETURNS TRIGGER AS $$
BEGIN
NEW.updated_at = NOW();
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
-- Create trigger
CREATE TRIGGER update_users_modtime
BEFORE UPDATE ON users
FOR EACH ROW
EXECUTE FUNCTION update_modified_column();
-- List triggers
\dft
SELECT * FROM pg_trigger WHERE tgrelid = 'users'::regclass;
-- Drop trigger
DROP TRIGGER update_users_modtime ON users;
JSON Operations
-- JSON column
CREATE TABLE events (
id SERIAL PRIMARY KEY,
data JSONB
);
-- Insert JSON
INSERT INTO events (data) VALUES ('{"type": "click", "count": 1}');
-- Query JSON
SELECT data->>'type' AS event_type FROM events;
SELECT * FROM events WHERE data->>'type' = 'click';
SELECT * FROM events WHERE (data->>'count')::int > 5;
-- Update JSON
UPDATE events SET data = jsonb_set(data, '{count}', '10') WHERE id = 1;
-- JSON aggregation
SELECT jsonb_agg(username) FROM users;
SELECT jsonb_object_agg(id, username) FROM users;
Full-Text Search
-- Create tsvector column
ALTER TABLE articles ADD COLUMN textsearch tsvector;
-- Update tsvector
UPDATE articles SET textsearch =
to_tsvector('english', title || ' ' || body);
-- Create index
CREATE INDEX idx_articles_textsearch ON articles USING GIN(textsearch);
-- Search
SELECT title FROM articles
WHERE textsearch @@ to_tsquery('english', 'postgresql & performance');
-- Ranking
SELECT title, ts_rank(textsearch, query) AS rank
FROM articles, to_tsquery('english', 'postgresql') query
WHERE textsearch @@ query
ORDER BY rank DESC;
User Management
-- Create user
CREATE USER myuser WITH PASSWORD 'mypassword';
-- Create role
CREATE ROLE readonly;
-- Grant privileges
GRANT SELECT ON ALL TABLES IN SCHEMA public TO readonly;
GRANT ALL PRIVILEGES ON DATABASE mydb TO myuser;
GRANT SELECT, INSERT, UPDATE ON users TO myuser;
-- Revoke privileges
REVOKE INSERT ON users FROM myuser;
-- Alter user
ALTER USER myuser WITH PASSWORD 'newpassword';
ALTER USER myuser WITH SUPERUSER;
-- Drop user
DROP USER myuser;
-- List users
\du
SELECT * FROM pg_user;
Backup and Restore
# Dump database
pg_dump -U username -d mydb > mydb_backup.sql
pg_dump -U username -d mydb -F c > mydb_backup.dump
# Dump specific table
pg_dump -U username -d mydb -t users > users_backup.sql
# Dump all databases
pg_dumpall -U postgres > all_dbs.sql
# Restore from SQL file
psql -U username -d mydb < mydb_backup.sql
# Restore from custom format
pg_restore -U username -d mydb mydb_backup.dump
# Restore specific table
pg_restore -U username -d mydb -t users mydb_backup.dump
Performance Tuning
-- Analyze table
ANALYZE users;
-- Vacuum
VACUUM users;
VACUUM FULL users;
VACUUM ANALYZE users;
-- Explain query
EXPLAIN SELECT * FROM users WHERE username = 'john';
EXPLAIN ANALYZE SELECT * FROM users WHERE username = 'john';
-- Query statistics
SELECT * FROM pg_stat_user_tables WHERE relname = 'users';
SELECT * FROM pg_stat_user_indexes WHERE relname = 'users';
-- Active connections
SELECT * FROM pg_stat_activity;
-- Kill query
SELECT pg_cancel_backend(pid);
SELECT pg_terminate_backend(pid);
-- Table size
SELECT pg_size_pretty(pg_total_relation_size('users'));
Configuration
# postgresql.conf key settings
# Memory
shared_buffers = 256MB # 25% of RAM
effective_cache_size = 1GB # 50-75% of RAM
work_mem = 4MB
maintenance_work_mem = 64MB
# WAL
wal_buffers = 16MB
checkpoint_completion_target = 0.9
max_wal_size = 1GB
# Query planner
random_page_cost = 1.1 # For SSD
effective_io_concurrency = 200 # For SSD
# Connections
max_connections = 100
# Logging
log_destination = 'stderr'
logging_collector = on
log_directory = 'pg_log'
log_filename = 'postgresql-%Y-%m-%d_%H%M%S.log'
log_statement = 'all'
log_duration = on
log_min_duration_statement = 1000 # Log queries > 1s
psql Commands
# Meta-commands
\? # Help on psql commands
\h ALTER TABLE # Help on SQL command
\l # List databases
\c dbname # Connect to database
\dt # List tables
\dt+ # List tables with sizes
\d tablename # Describe table
\d+ tablename # Detailed table info
\di # List indexes
\dv # List views
\df # List functions
\du # List users
\dn # List schemas
\timing # Toggle timing
\x # Toggle expanded output
\q # Quit
\! command # Execute shell command
\i file.sql # Execute SQL file
\o file.txt # Output to file
\o # Output to stdout
Quick Reference
| Command | Description |
|---|---|
| \l | List databases |
| \c database | Connect to database |
| \dt | List tables |
| \d table | Describe table |
| \di | List indexes |
| \du | List users |
| EXPLAIN | Show query plan |
| VACUUM | Cleanup database |
| pg_dump | Backup database |
| psql -f file.sql | Execute SQL file |
PostgreSQL is a robust, feature-rich database system suitable for applications ranging from small projects to large-scale enterprise systems.
SQLite
SQLite is a C-language library that implements a small, fast, self-contained, high-reliability, full-featured SQL database engine. It's the most widely deployed database in the world.
Overview
SQLite is embedded into the application, requiring no separate server process. The entire database is stored in a single cross-platform file.
Key Features:
- Serverless, zero-configuration
- Self-contained (single file database)
- Cross-platform
- ACID compliant
- Supports most SQL standards
- Public domain (no license required)
Installation
# Ubuntu/Debian
sudo apt update
sudo apt install sqlite3
# macOS (pre-installed, or use Homebrew)
brew install sqlite
# CentOS/RHEL
sudo yum install sqlite
# Verify
sqlite3 --version
Basic Usage
# Create/open database
sqlite3 mydb.db
# Open existing database
sqlite3 existing.db
# Execute command from shell
sqlite3 mydb.db "SELECT * FROM users;"
# Execute SQL file
sqlite3 mydb.db < script.sql
# Dump database
sqlite3 mydb.db .dump > backup.sql
# Exit
.quit
.exit
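Because SQLite is embedded, no shell is required at all: Python's standard sqlite3 module drives the same engine. A minimal sketch (the table and data are illustrative):

```python
import sqlite3

# ":memory:" keeps the database in RAM; pass a path like "mydb.db" to persist to a file
con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, username TEXT UNIQUE)")
con.execute("INSERT INTO users (username) VALUES (?)", ("john",))  # parameterized query
con.commit()

rows = con.execute("SELECT id, username FROM users").fetchall()
print(rows)  # [(1, 'john')] -- INTEGER PRIMARY KEY auto-assigns the rowid
con.close()
```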
Database Operations
-- Attach database
ATTACH DATABASE 'other.db' AS other;
-- List databases
.databases
-- Detach
DETACH DATABASE other;
-- Backup database
.backup backup.db
-- Restore from backup
.restore backup.db
Table Operations
-- Create table
CREATE TABLE users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL UNIQUE,
email TEXT NOT NULL UNIQUE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
-- List tables
.tables
.schema
-- Show table schema
.schema users
PRAGMA table_info(users);
-- Drop table
DROP TABLE users;
DROP TABLE IF EXISTS users;
-- Rename table
ALTER TABLE users RENAME TO customers;
-- Add column
ALTER TABLE users ADD COLUMN age INTEGER;
-- Rename column (SQLite 3.25.0+)
ALTER TABLE users RENAME COLUMN username TO user_name;
-- Drop column (SQLite 3.35.0+)
ALTER TABLE users DROP COLUMN age;
Data Types
-- SQLite has 5 storage classes
-- INTEGER, REAL, TEXT, BLOB, NULL
CREATE TABLE examples (
int_col INTEGER,
real_col REAL,
text_col TEXT,
blob_col BLOB,
-- Type affinity examples
bool_col BOOLEAN, -- Stored as INTEGER (0 or 1)
date_col DATE, -- Stored as TEXT, INTEGER, or REAL
datetime_col DATETIME,
varchar_col VARCHAR(100), -- Stored as TEXT
decimal_col DECIMAL(10,2) -- Stored as REAL or TEXT
);
CRUD Operations
-- Insert
INSERT INTO users (username, email)
VALUES ('john', 'john@example.com');
-- Insert multiple
INSERT INTO users (username, email) VALUES
('alice', 'alice@example.com'),
('bob', 'bob@example.com');
-- Insert or replace
INSERT OR REPLACE INTO users (id, username, email)
VALUES (1, 'john', 'newemail@example.com');
-- Insert or ignore
INSERT OR IGNORE INTO users (username, email)
VALUES ('john', 'john@example.com');
-- Select
SELECT * FROM users;
SELECT username, email FROM users WHERE id = 1;
SELECT * FROM users WHERE username LIKE 'jo%';
SELECT * FROM users ORDER BY created_at DESC LIMIT 10;
SELECT * FROM users LIMIT 10 OFFSET 20;
-- Update
UPDATE users SET email = 'newemail@example.com' WHERE id = 1;
-- Delete
DELETE FROM users WHERE id = 1;
DELETE FROM users WHERE created_at < '2023-01-01';
Indexes
-- Create index
CREATE INDEX idx_users_username ON users(username);
CREATE INDEX idx_users_email ON users(email);
-- Unique index
CREATE UNIQUE INDEX idx_users_username_unique ON users(username);
-- Composite index
CREATE INDEX idx_users_name_email ON users(username, email);
-- Partial index
CREATE INDEX idx_active_users ON users(username) WHERE active = 1;
-- Expression index
CREATE INDEX idx_users_lower_username ON users(LOWER(username));
-- List indexes
.indexes
.indexes users
PRAGMA index_list(users);
-- Show index info
PRAGMA index_info(idx_users_username);
-- Drop index
DROP INDEX idx_users_username;
Constraints
-- Primary key
CREATE TABLE products (
id INTEGER PRIMARY KEY, -- Alias for rowid
name TEXT NOT NULL
);
-- Composite primary key
CREATE TABLE order_items (
order_id INTEGER,
product_id INTEGER,
quantity INTEGER,
PRIMARY KEY (order_id, product_id)
);
-- Foreign key (must enable)
PRAGMA foreign_keys = ON;
CREATE TABLE orders (
id INTEGER PRIMARY KEY,
user_id INTEGER,
FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE
);
-- Unique constraint
CREATE TABLE users (
id INTEGER PRIMARY KEY,
email TEXT UNIQUE
);
-- Check constraint
CREATE TABLE products (
id INTEGER PRIMARY KEY,
price REAL CHECK(price > 0),
quantity INTEGER CHECK(quantity >= 0)
);
-- Not null
CREATE TABLE users (
id INTEGER PRIMARY KEY,
username TEXT NOT NULL
);
-- Default value
CREATE TABLE users (
id INTEGER PRIMARY KEY,
active INTEGER DEFAULT 1,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
Joins
-- Inner join
SELECT u.username, o.id AS order_id
FROM users u
INNER JOIN orders o ON u.id = o.user_id;
-- Left join
SELECT u.username, o.id AS order_id
FROM users u
LEFT JOIN orders o ON u.id = o.user_id;
-- Cross join
SELECT u.username, p.name
FROM users u
CROSS JOIN products p;
-- Natural join (not recommended)
SELECT * FROM users NATURAL JOIN orders;
Aggregations
-- Count
SELECT COUNT(*) FROM users;
SELECT COUNT(DISTINCT email) FROM users;
-- Sum, Avg, Min, Max
SELECT
COUNT(*) AS total_orders,
SUM(amount) AS total_amount,
AVG(amount) AS avg_amount,
MIN(amount) AS min_amount,
MAX(amount) AS max_amount
FROM orders;
-- Group by
SELECT user_id, COUNT(*) AS order_count
FROM orders
GROUP BY user_id;
-- Having
SELECT user_id, COUNT(*) AS order_count
FROM orders
GROUP BY user_id
HAVING COUNT(*) > 5;
-- Group concat
SELECT user_id, GROUP_CONCAT(product_name, ', ') AS products
FROM order_items
GROUP BY user_id;
Transactions
-- Begin transaction
BEGIN TRANSACTION;
INSERT INTO users (username, email) VALUES ('test', 'test@example.com');
UPDATE accounts SET balance = balance - 100 WHERE id = 1;
UPDATE accounts SET balance = balance + 100 WHERE id = 2;
-- Commit
COMMIT;
-- Rollback
ROLLBACK;
-- Transaction modes
BEGIN DEFERRED TRANSACTION; -- Default
BEGIN IMMEDIATE TRANSACTION; -- Acquire write lock
BEGIN EXCLUSIVE TRANSACTION; -- Exclusive access
-- Savepoint
BEGIN;
INSERT INTO users (username, email) VALUES ('test', 'test@example.com');
SAVEPOINT sp1;
UPDATE users SET email = 'new@example.com' WHERE username = 'test';
ROLLBACK TO sp1;
COMMIT;
Views
-- Create view
CREATE VIEW active_users AS
SELECT id, username, email
FROM users
WHERE active = 1;
-- Use view
SELECT * FROM active_users;
-- Temporary view
CREATE TEMP VIEW temp_users AS
SELECT * FROM users WHERE created_at > date('now', '-7 days');
-- Drop view
DROP VIEW active_users;
Triggers
-- Before insert trigger
CREATE TRIGGER validate_email
BEFORE INSERT ON users
BEGIN
SELECT CASE
WHEN NEW.email NOT LIKE '%@%' THEN
RAISE(ABORT, 'Invalid email format')
END;
END;
-- After insert trigger
CREATE TRIGGER log_user_creation
AFTER INSERT ON users
BEGIN
INSERT INTO audit_log (table_name, action, timestamp)
VALUES ('users', 'INSERT', datetime('now'));
END;
-- Update trigger
CREATE TRIGGER update_modified_time
AFTER UPDATE ON users
BEGIN
UPDATE users SET updated_at = datetime('now')
WHERE id = NEW.id;
END;
-- Instead of trigger (for views)
CREATE TRIGGER update_active_users
INSTEAD OF UPDATE ON active_users
BEGIN
UPDATE users SET email = NEW.email WHERE id = NEW.id;
END;
-- List triggers
.schema users
SELECT * FROM sqlite_master WHERE type = 'trigger';
-- Drop trigger
DROP TRIGGER validate_email;
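The validate_email trigger above can be exercised from Python; a RAISE(ABORT, ...) inside a trigger surfaces to the sqlite3 driver as an IntegrityError:

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT)")
con.execute("""
    CREATE TRIGGER validate_email
    BEFORE INSERT ON users
    BEGIN
        SELECT CASE
            WHEN NEW.email NOT LIKE '%@%' THEN
                RAISE(ABORT, 'Invalid email format')
        END;
    END
""")

con.execute("INSERT INTO users (email) VALUES ('ok@example.com')")  # passes

try:
    con.execute("INSERT INTO users (email) VALUES ('not-an-email')")
    failed = False
except sqlite3.IntegrityError:
    failed = True  # RAISE(ABORT, ...) aborts the statement

count = con.execute("SELECT COUNT(*) FROM users").fetchone()[0]
print(failed, count)  # True 1 -- only the valid row was stored
```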
Date and Time
-- Current date/time
SELECT date('now'); -- 2024-01-15
SELECT time('now'); -- 14:30:45
SELECT datetime('now'); -- 2024-01-15 14:30:45
SELECT strftime('%Y-%m-%d %H:%M', 'now');
-- Date arithmetic
SELECT date('now', '+7 days');
SELECT date('now', '-1 month');
SELECT datetime('now', '+5 hours');
SELECT date('now', 'start of month');
SELECT date('now', 'start of year');
-- Extract parts
SELECT strftime('%Y', 'now') AS year;
SELECT strftime('%m', 'now') AS month;
SELECT strftime('%d', 'now') AS day;
SELECT strftime('%H', 'now') AS hour;
-- Julian day
SELECT julianday('now');
SELECT julianday('now') - julianday('2024-01-01');
-- Unix timestamp
SELECT strftime('%s', 'now'); -- Unix timestamp
SELECT datetime(1234567890, 'unixepoch'); -- From timestamp
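The date functions above all run inside SQLite itself, so they behave the same from any driver; a quick check from Python:

```python
import sqlite3

con = sqlite3.connect(":memory:")

# julianday differences give elapsed days as a float
(days,) = con.execute(
    "SELECT julianday('2024-03-01') - julianday('2024-01-01')"
).fetchone()
print(days)  # 60.0 (2024 is a leap year: 31 + 29 days)

# Date arithmetic with modifiers
(next_week,) = con.execute("SELECT date('2024-01-15', '+7 days')").fetchone()
print(next_week)  # 2024-01-22

# Round-trip through a Unix timestamp
(ts,) = con.execute("SELECT strftime('%s', '2024-01-15 00:00:00')").fetchone()
(back,) = con.execute("SELECT datetime(?, 'unixepoch')", (int(ts),)).fetchone()
print(back)  # 2024-01-15 00:00:00
```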
JSON Operations (SQLite 3.38.0+)
-- JSON functions
SELECT json('{"name":"John","age":30}');
-- Extract value
SELECT json_extract('{"name":"John","age":30}', '$.name');
SELECT '{"name":"John","age":30}' -> 'name'; -- Shorthand
-- Array operations
SELECT json_each.value
FROM json_each('[1,2,3,4,5]');
-- Store JSON
CREATE TABLE events (
id INTEGER PRIMARY KEY,
data TEXT
);
INSERT INTO events (data) VALUES ('{"type":"click","count":1}');
-- Query JSON
SELECT * FROM events
WHERE json_extract(data, '$.type') = 'click';
-- Update JSON
UPDATE events
SET data = json_set(data, '$.count', json_extract(data, '$.count') + 1)
WHERE id = 1;
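The json_extract/json_set pattern above can be verified from Python, assuming the bundled SQLite was built with the JSON functions (enabled by default since SQLite 3.38, and compiled in on most earlier Python builds):

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE events (id INTEGER PRIMARY KEY, data TEXT)")
con.execute("""INSERT INTO events (data) VALUES ('{"type":"click","count":1}')""")

# json_extract pulls a typed SQL value out of the stored JSON text
(kind,) = con.execute(
    "SELECT json_extract(data, '$.type') FROM events WHERE id = 1"
).fetchone()
print(kind)  # click

# json_set rewrites one key while leaving the rest of the document intact
con.execute("""
    UPDATE events
    SET data = json_set(data, '$.count', json_extract(data, '$.count') + 1)
    WHERE id = 1
""")
(count,) = con.execute(
    "SELECT json_extract(data, '$.count') FROM events WHERE id = 1"
).fetchone()
print(count)  # 2
```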
Full-Text Search
-- Create FTS5 table
CREATE VIRTUAL TABLE articles_fts USING fts5(
title,
body,
content=articles,
content_rowid=id
);
-- Populate FTS table
INSERT INTO articles_fts(rowid, title, body)
SELECT id, title, body FROM articles;
-- Search
SELECT * FROM articles_fts WHERE articles_fts MATCH 'sqlite performance';
-- Ranking
SELECT *, rank FROM articles_fts
WHERE articles_fts MATCH 'sqlite'
ORDER BY rank;
-- Phrase search
SELECT * FROM articles_fts WHERE articles_fts MATCH '"sqlite database"';
-- Column-specific search (FTS5 puts the column filter inside the query string)
SELECT * FROM articles_fts WHERE articles_fts MATCH 'title: tutorial';
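A minimal FTS5 round trip can be tried from Python's stdlib sqlite3 (assuming the bundled SQLite was compiled with FTS5, which is the common case); the docs table and its rows are hypothetical:

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE VIRTUAL TABLE docs USING fts5(title, body)")
con.executemany(
    "INSERT INTO docs VALUES (?, ?)",
    [
        ("SQLite performance tips", "How to tune the sqlite database"),
        ("Cooking pasta", "Boil water, add salt"),
    ],
)

# Full-text match, best results first via the built-in rank column
hits = con.execute(
    "SELECT title FROM docs WHERE docs MATCH 'sqlite' ORDER BY rank"
).fetchall()
print(hits)  # [('SQLite performance tips',)]

# Column filter inside the query string restricts the match to one column
hits_title = con.execute(
    "SELECT title FROM docs WHERE docs MATCH 'title: pasta'"
).fetchall()
print(hits_title)  # [('Cooking pasta',)]
```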
Pragma Statements
-- Database info
PRAGMA database_list;
PRAGMA table_info(users);
PRAGMA index_list(users);
PRAGMA foreign_key_list(orders);
-- Performance
PRAGMA cache_size = 10000; -- Pages in cache
PRAGMA page_size = 4096; -- Page size in bytes
PRAGMA journal_mode = WAL; -- Write-Ahead Logging
PRAGMA synchronous = NORMAL; -- Sync mode
PRAGMA temp_store = MEMORY; -- Temp tables in memory
-- Foreign keys
PRAGMA foreign_keys = ON;
PRAGMA foreign_keys; -- Check status
-- Integrity check
PRAGMA integrity_check;
PRAGMA quick_check;
-- Database size
PRAGMA page_count;
PRAGMA page_size;
-- Total size = page_count * page_size
-- Optimization
PRAGMA optimize;
VACUUM;
Performance Optimization
-- Enable WAL mode (Write-Ahead Logging)
PRAGMA journal_mode = WAL;
-- Increase cache size
PRAGMA cache_size = -64000; -- 64MB
-- Disable synchronous (faster but less safe)
PRAGMA synchronous = OFF;
PRAGMA synchronous = NORMAL; -- Balanced
-- Analyze tables
ANALYZE;
ANALYZE users;
-- Vacuum database
VACUUM;
-- Batch inserts
BEGIN TRANSACTION;
-- Multiple INSERT statements
COMMIT;
-- Use prepared statements (in code)
-- Better performance and security
-- Indexes for frequently queried columns
CREATE INDEX idx_users_email ON users(email);
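The "batch inserts inside one transaction" advice above maps to executemany in Python: one transaction, one prepared statement, many rows. A small sketch with a hypothetical table:

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE t (x INTEGER)")

# Wrapping many INSERTs in one transaction avoids a disk sync per row;
# executemany() additionally reuses a single prepared statement.
rows = [(i,) for i in range(10_000)]
with con:  # the connection context manager wraps the block in BEGIN...COMMIT
    con.executemany("INSERT INTO t (x) VALUES (?)", rows)

(n,) = con.execute("SELECT COUNT(*) FROM t").fetchone()
print(n)  # 10000
```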
Backup and Recovery
# Backup database
sqlite3 mydb.db ".backup backup.db"
sqlite3 mydb.db .dump > backup.sql
cp mydb.db mydb_backup.db # Simple copy
# Restore from backup
sqlite3 newdb.db ".restore backup.db"
sqlite3 newdb.db < backup.sql
# Export to CSV
.mode csv
.output users.csv
SELECT * FROM users;
.output stdout
# Import from CSV
.mode csv
.import users.csv users
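From application code, the same online-backup machinery behind the CLI's .backup command is exposed as Connection.backup() in Python's sqlite3, and is safe to run while the source database is in use:

```python
import sqlite3

src = sqlite3.connect(":memory:")
src.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
src.execute("INSERT INTO users (name) VALUES ('alice'), ('bob')")
src.commit()

# Destination would normally be a file, e.g. sqlite3.connect("backup.db")
dest = sqlite3.connect(":memory:")
src.backup(dest)  # copies the whole database page by page

names = dest.execute("SELECT name FROM users ORDER BY id").fetchall()
print(names)  # [('alice',), ('bob',)]
```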
SQLite CLI Commands
# Meta-commands
.help # Show help
.databases # List databases
.tables # List tables
.schema # Show all schemas
.schema users # Show table schema
.indexes users # Show indexes
.mode column # Column output mode
.mode csv # CSV output mode
.mode json # JSON output mode
.headers on # Show column headers
.width 10 20 30 # Set column widths
.output file.txt # Output to file
.output stdout # Output to screen
.read file.sql # Execute SQL file
.timer on # Show execution time
.quit # Exit
Common Patterns
-- Upsert (Insert or Update)
INSERT INTO users (id, username, email)
VALUES (1, 'john', 'john@example.com')
ON CONFLICT(id) DO UPDATE SET
username = excluded.username,
email = excluded.email;
-- Conditional insert
INSERT INTO users (username, email)
SELECT 'john', 'john@example.com'
WHERE NOT EXISTS (
SELECT 1 FROM users WHERE username = 'john'
);
-- Auto-increment
CREATE TABLE users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT
);
-- Get last insert rowid
SELECT last_insert_rowid();
-- Pagination
SELECT * FROM users
ORDER BY id
LIMIT 10 OFFSET 20;
-- Random row
SELECT * FROM users ORDER BY RANDOM() LIMIT 1;
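The upsert pattern above (which needs SQLite 3.24+) and last_insert_rowid() can be demonstrated from Python, where the driver exposes the latter as cursor.lastrowid:

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, username TEXT, email TEXT)")

# ON CONFLICT upsert: second execution updates instead of failing
upsert = """
    INSERT INTO users (id, username, email) VALUES (?, ?, ?)
    ON CONFLICT(id) DO UPDATE SET
        username = excluded.username,
        email = excluded.email
"""
con.execute(upsert, (1, "john", "john@example.com"))   # inserts
con.execute(upsert, (1, "john", "new@example.com"))    # updates in place

rows = con.execute("SELECT id, email FROM users").fetchall()
print(rows)  # [(1, 'new@example.com')]

# last_insert_rowid() via the driver
cur = con.execute("INSERT INTO users (username) VALUES ('jane')")
print(cur.lastrowid)  # 2
```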
Best Practices
-- 1. Enable foreign keys
PRAGMA foreign_keys = ON;
-- 2. Use WAL mode for better concurrency
PRAGMA journal_mode = WAL;
-- 3. Use transactions for bulk operations
BEGIN TRANSACTION;
-- Multiple operations
COMMIT;
-- 4. Create indexes for frequently queried columns
CREATE INDEX idx_users_email ON users(email);
-- 5. Use INTEGER PRIMARY KEY for auto-increment
CREATE TABLE users (
id INTEGER PRIMARY KEY,
username TEXT
);
-- 6. Analyze database periodically
ANALYZE;
-- 7. Use prepared statements in code
-- Prevents SQL injection and improves performance
-- 8. Vacuum database periodically
VACUUM;
-- 9. Use appropriate data types
-- SQLite is flexible but using correct types helps
-- 10. Regular backups
-- Use .backup command or copy the database file
Quick Reference
| Command | Description |
|---|---|
| .tables | List tables |
| .schema TABLE | Show table structure |
| .mode column | Set output format |
| .headers on | Show column headers |
| .backup FILE | Backup database |
| .import FILE TABLE | Import CSV |
| PRAGMA foreign_keys=ON | Enable foreign keys |
| PRAGMA journal_mode=WAL | Enable WAL mode |
| VACUUM | Optimize database |
| ANALYZE | Update statistics |
SQLite is ideal for embedded systems, mobile apps, desktop applications, and scenarios where a simple, reliable, serverless database is needed.
DuckDB
DuckDB is an in-process SQL OLAP (Online Analytical Processing) database management system designed for analytical query workloads. It's often described as "SQLite for analytics."
Overview
DuckDB is optimized for analytical queries with columnar storage, vectorized execution, and minimal dependencies.
Key Features:
- In-process, embedded database
- Columnar storage for analytics
- ACID compliant
- Vectorized query execution
- No external dependencies
- SQL compatible
- Direct querying of CSV, Parquet, JSON
Installation
# Ubuntu/Debian
sudo apt install duckdb
# macOS
brew install duckdb
# Python
pip install duckdb
# From binary
wget https://github.com/duckdb/duckdb/releases/download/v0.9.2/duckdb_cli-linux-amd64.zip
unzip duckdb_cli-linux-amd64.zip
sudo mv duckdb /usr/local/bin/
# Verify
duckdb --version
Basic Usage
# Start DuckDB CLI
duckdb
# Create/open database file
duckdb mydb.duckdb
# In-memory database
duckdb :memory:
# Execute command from shell
duckdb mydb.duckdb "SELECT * FROM users;"
# Execute SQL file
duckdb mydb.duckdb < script.sql
# Exit
.quit
Python API
import duckdb
# Connect to database
con = duckdb.connect('mydb.duckdb')
# In-memory database
con = duckdb.connect(':memory:')
# Execute query
result = con.execute("SELECT * FROM users").fetchall()
# Fetch as DataFrame
df = con.execute("SELECT * FROM users").df()
# Direct DataFrame query
import pandas as pd
df = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
result = duckdb.query("SELECT * FROM df WHERE a > 1").df()
# Close connection
con.close()
Table Operations
-- Create table
CREATE TABLE users (
id INTEGER PRIMARY KEY,
username VARCHAR,
email VARCHAR,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
-- Create table from query
CREATE TABLE new_users AS
SELECT * FROM users WHERE created_at > '2024-01-01';
-- Show tables
SHOW TABLES;
.tables
-- Describe table
DESCRIBE users;
.schema users
-- Drop table
DROP TABLE users;
Reading External Files
-- Read CSV
SELECT * FROM read_csv_auto('data.csv');
-- Read CSV with options
SELECT * FROM read_csv('data.csv',
header=true,
delim=',',
quote='"',
types={'id': 'INTEGER', 'name': 'VARCHAR'}
);
-- Create table from CSV
CREATE TABLE users AS
SELECT * FROM read_csv_auto('users.csv');
-- Read Parquet
SELECT * FROM read_parquet('data.parquet');
SELECT * FROM 'data.parquet'; -- Shorthand
-- Read multiple Parquet files
SELECT * FROM read_parquet(['file1.parquet', 'file2.parquet']);
SELECT * FROM read_parquet('data/*.parquet');
-- Read JSON
SELECT * FROM read_json_auto('data.json');
SELECT * FROM read_json('data.json', format='array');
-- Read JSON lines
SELECT * FROM read_json_auto('data.jsonl', format='newline_delimited');
Writing to Files
-- Export to CSV
COPY users TO 'users.csv' (HEADER, DELIMITER ',');
-- Export to Parquet
COPY users TO 'users.parquet' (FORMAT PARQUET);
-- Export query result
COPY (SELECT * FROM users WHERE active = true)
TO 'active_users.parquet' (FORMAT PARQUET);
-- Export to JSON
COPY users TO 'users.json';
CRUD Operations
-- Insert
INSERT INTO users (username, email)
VALUES ('john', 'john@example.com');
-- Insert multiple
INSERT INTO users (username, email) VALUES
('alice', 'alice@example.com'),
('bob', 'bob@example.com');
-- Insert from SELECT
INSERT INTO users (username, email)
SELECT username, email FROM temp_users;
-- Select
SELECT * FROM users;
SELECT * FROM users WHERE username LIKE 'jo%';
SELECT * FROM users ORDER BY created_at DESC LIMIT 10;
-- Update
UPDATE users SET email = 'newemail@example.com' WHERE id = 1;
-- Delete
DELETE FROM users WHERE id = 1;
Analytical Queries
-- Window functions
SELECT
username,
created_at,
ROW_NUMBER() OVER (ORDER BY created_at) AS row_num,
RANK() OVER (ORDER BY created_at) AS rank,
DENSE_RANK() OVER (ORDER BY created_at) AS dense_rank,
NTILE(4) OVER (ORDER BY created_at) AS quartile
FROM users;
-- Moving average
SELECT
date,
revenue,
AVG(revenue) OVER (ORDER BY date ROWS BETWEEN 6 PRECEDING AND CURRENT ROW) AS moving_avg_7d
FROM sales;
-- Cumulative sum
SELECT
date,
amount,
SUM(amount) OVER (ORDER BY date) AS cumulative_total
FROM transactions;
-- Percent rank
SELECT
username,
score,
PERCENT_RANK() OVER (ORDER BY score) AS percentile
FROM scores;
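These window functions are standard SQL rather than DuckDB-specific, so the same query also runs on the SQLite build bundled with Python (3.25+); a sketch with a hypothetical scores table, using a username tie-breaker so the output is deterministic:

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE scores (username TEXT, score INTEGER)")
con.executemany(
    "INSERT INTO scores VALUES (?, ?)",
    [("a", 10), ("b", 30), ("c", 20), ("d", 30)],
)

rows = con.execute("""
    SELECT username,
           ROW_NUMBER() OVER (ORDER BY score, username) AS row_num,
           RANK()       OVER (ORDER BY score)           AS rnk,
           SUM(score)   OVER (ORDER BY score, username
                              ROWS UNBOUNDED PRECEDING) AS running
    FROM scores
    ORDER BY score, username
""").fetchall()
for r in rows:
    print(r)
# ('a', 1, 1, 10)
# ('c', 2, 2, 30)
# ('b', 3, 3, 60)   <- b and d tie on score, so RANK gives both 3
# ('d', 4, 3, 90)
```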
Aggregations
-- Basic aggregations
SELECT
COUNT(*) AS total,
COUNT(DISTINCT user_id) AS unique_users,
SUM(amount) AS total_amount,
AVG(amount) AS avg_amount,
MIN(amount) AS min_amount,
MAX(amount) AS max_amount,
STDDEV(amount) AS std_dev,
MEDIAN(amount) AS median_amount
FROM orders;
-- Group by with ROLLUP
SELECT
category,
subcategory,
SUM(amount) AS total
FROM sales
GROUP BY ROLLUP (category, subcategory);
-- Group by with CUBE
SELECT
region,
product,
SUM(revenue) AS total
FROM sales
GROUP BY CUBE (region, product);
-- GROUPING SETS
SELECT
region,
product,
SUM(revenue) AS total
FROM sales
GROUP BY GROUPING SETS ((region), (product), ());
Time Series
-- Generate date series
SELECT * FROM generate_series(
TIMESTAMP '2024-01-01',
TIMESTAMP '2024-12-31',
INTERVAL '1 day'
) AS t(date);
-- Time bucket
SELECT
time_bucket(INTERVAL '1 hour', timestamp) AS hour,
COUNT(*) AS events,
AVG(value) AS avg_value
FROM events
GROUP BY hour
ORDER BY hour;
-- Date truncation
SELECT
date_trunc('month', created_at) AS month,
COUNT(*) AS user_count
FROM users
GROUP BY month;
-- Extract date parts
SELECT
EXTRACT(year FROM created_at) AS year,
EXTRACT(month FROM created_at) AS month,
EXTRACT(day FROM created_at) AS day,
EXTRACT(hour FROM created_at) AS hour
FROM events;
Joins
-- Inner join
SELECT u.username, o.amount
FROM users u
INNER JOIN orders o ON u.id = o.user_id;
-- Left join
SELECT u.username, o.amount
FROM users u
LEFT JOIN orders o ON u.id = o.user_id;
-- Right join
SELECT u.username, o.amount
FROM users u
RIGHT JOIN orders o ON u.id = o.user_id;
-- Full outer join
SELECT u.username, o.amount
FROM users u
FULL OUTER JOIN orders o ON u.id = o.user_id;
-- Cross join
SELECT u.username, p.name
FROM users u
CROSS JOIN products p;
-- Join with USING
SELECT * FROM users u
JOIN orders o USING (user_id);
-- ASOF join (temporal join)
SELECT * FROM trades
ASOF JOIN quotes
ON trades.symbol = quotes.symbol
AND trades.timestamp >= quotes.timestamp;
Common Table Expressions (CTEs)
-- Basic CTE
WITH active_users AS (
SELECT * FROM users WHERE active = true
)
SELECT * FROM active_users WHERE created_at > '2024-01-01';
-- Multiple CTEs
WITH
active_users AS (
SELECT * FROM users WHERE active = true
),
recent_orders AS (
SELECT * FROM orders WHERE created_at > '2024-01-01'
)
SELECT u.username, COUNT(o.id) AS order_count
FROM active_users u
LEFT JOIN recent_orders o ON u.id = o.user_id
GROUP BY u.username;
-- Recursive CTE
WITH RECURSIVE countdown(n) AS (
SELECT 10 AS n
UNION ALL
SELECT n - 1 FROM countdown WHERE n > 1
)
SELECT * FROM countdown;
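Recursive CTEs are also standard SQL, so the countdown query above can be checked with Python's stdlib sqlite3 just as well as in DuckDB:

```python
import sqlite3

con = sqlite3.connect(":memory:")
rows = con.execute("""
    WITH RECURSIVE countdown(n) AS (
        SELECT 10                                   -- anchor member
        UNION ALL
        SELECT n - 1 FROM countdown WHERE n > 1     -- recursive member
    )
    SELECT n FROM countdown
""").fetchall()

nums = [r[0] for r in rows]
print(nums)  # [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
```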
Pivot and Unpivot
-- Pivot
PIVOT sales
ON product_category
USING SUM(amount)
GROUP BY region;
-- Manual pivot
SELECT
region,
SUM(CASE WHEN category = 'Electronics' THEN amount ELSE 0 END) AS electronics,
SUM(CASE WHEN category = 'Clothing' THEN amount ELSE 0 END) AS clothing,
SUM(CASE WHEN category = 'Food' THEN amount ELSE 0 END) AS food
FROM sales
GROUP BY region;
-- Unpivot
UNPIVOT sales
ON electronics, clothing, food
INTO NAME category VALUE amount;
String Functions
-- String operations
SELECT
UPPER(username) AS upper_name,
LOWER(username) AS lower_name,
CONCAT(first_name, ' ', last_name) AS full_name,
SUBSTRING(email, 1, 5) AS email_prefix,
LENGTH(username) AS name_length,
REPLACE(email, '@gmail.com', '@example.com') AS new_email,
SPLIT_PART(email, '@', 1) AS email_user,
TRIM(username) AS trimmed,
REGEXP_MATCHES(text, '[0-9]+') AS has_digits, -- returns a boolean
REGEXP_REPLACE(text, '[0-9]', 'X') AS masked
FROM users;
-- String aggregation
SELECT
category,
STRING_AGG(product_name, ', ') AS products
FROM products
GROUP BY category;
-- List functions
SELECT
['a', 'b', 'c'] AS my_list,
LIST_VALUE('a', 'b', 'c') AS another_list,
[1, 2, 3] AS numeric_list;
SELECT list[1] FROM (SELECT [1, 2, 3] AS list); -- lists are 1-indexed
Array and Struct Operations
-- Arrays
SELECT [1, 2, 3, 4, 5] AS numbers;
SELECT LIST_VALUE(1, 2, 3, 4, 5) AS numbers;
SELECT UNNEST([1, 2, 3]) AS num;
-- Array aggregation
SELECT LIST(username) AS all_users FROM users;
-- Struct
SELECT {'name': 'John', 'age': 30} AS person;
SELECT person.name FROM (SELECT {'name': 'John', 'age': 30} AS person);
-- Nested structures
SELECT {
'user': {'name': 'John', 'email': 'john@example.com'},
'orders': [1, 2, 3]
} AS complex_data;
Constraints and Indexes
-- Primary key
CREATE TABLE users (
id INTEGER PRIMARY KEY,
username VARCHAR UNIQUE NOT NULL
);
-- Check constraint
CREATE TABLE products (
id INTEGER PRIMARY KEY,
price DECIMAL CHECK (price > 0),
quantity INTEGER CHECK (quantity >= 0)
);
-- Create index
CREATE INDEX idx_users_email ON users(email);
-- Drop index
DROP INDEX idx_users_email;
-- Show indexes
SELECT * FROM duckdb_indexes();
Transactions
-- Begin transaction
BEGIN TRANSACTION;
INSERT INTO users (username, email) VALUES ('test', 'test@example.com');
UPDATE accounts SET balance = balance - 100 WHERE id = 1;
-- Commit
COMMIT;
-- Rollback
ROLLBACK;
Views
-- Create view
CREATE VIEW active_users AS
SELECT id, username, email
FROM users
WHERE active = true;
-- Use view
SELECT * FROM active_users;
-- Drop view
DROP VIEW active_users;
Performance Optimization
-- Analyze query plan
EXPLAIN SELECT * FROM users WHERE username = 'john';
EXPLAIN ANALYZE SELECT * FROM users JOIN orders ON users.id = orders.user_id;
-- Vacuum and analyze
VACUUM;
ANALYZE users;
-- Parallel query execution (automatic)
SET threads TO 4;
-- Memory limit
SET memory_limit = '4GB';
-- Temp directory
SET temp_directory = '/path/to/temp';
Settings and Configuration
-- Show settings
SELECT * FROM duckdb_settings();
-- Set configuration
SET memory_limit = '8GB';
SET threads TO 8;
SET max_memory = '16GB'; -- alias for memory_limit
SET temp_directory = '/tmp';
-- Progress bar
SET enable_progress_bar = true;
-- Profiling
SET enable_profiling = true;
SET profiling_mode = 'detailed';
Importing from Other Databases
-- Attach SQLite database (requires the sqlite extension: INSTALL sqlite; LOAD sqlite;)
ATTACH 'mydb.sqlite' AS sqlite_db (TYPE SQLITE);
SELECT * FROM sqlite_db.users;
-- Attach PostgreSQL (requires the postgres extension: INSTALL postgres; LOAD postgres;)
ATTACH 'dbname=mydb user=postgres host=localhost' AS pg_db (TYPE POSTGRES);
SELECT * FROM pg_db.users;
-- Copy data
CREATE TABLE local_users AS
SELECT * FROM pg_db.users;
-- Detach
DETACH sqlite_db;
Python Integration
import duckdb
import pandas as pd
# Create connection
con = duckdb.connect('mydb.duckdb')
# Query to DataFrame
df = con.execute("SELECT * FROM users").df()
# Register DataFrame as table
con.register('df_users', df)
result = con.execute("SELECT * FROM df_users WHERE age > 30").df()
# Direct query on DataFrame
result = duckdb.query("SELECT * FROM df WHERE column_a > 10").df()
# Arrow integration
import pyarrow as pa
arrow_table = con.execute("SELECT * FROM users").arrow()
# Register Arrow table
con.register('arrow_users', arrow_table)
# Relation API
rel = con.table('users')
result = rel.filter('age > 30').project('username, email').df()
# Close
con.close()
CLI Commands
# Meta-commands
.help # Show help
.tables # List tables
.schema # Show all schemas
.schema users # Show table schema
.mode # Show output mode
.mode csv # Set CSV output
.mode json # Set JSON output
.mode markdown # Set Markdown output
.output file.csv # Output to file
.timer on # Show query timing
.maxrows 100 # Limit output rows
.quit # Exit
Best Practices
-- 1. Use columnar storage (Parquet) for large datasets
COPY large_table TO 'data.parquet' (FORMAT PARQUET);
-- 2. Leverage parallel execution
SET threads TO 8;
-- 3. Use appropriate data types
CREATE TABLE optimized (
id INTEGER,
name VARCHAR,
value DOUBLE,
date DATE
);
-- 4. Create indexes for frequently filtered columns
CREATE INDEX idx_users_email ON users(email);
-- 5. Use window functions instead of self-joins
SELECT username, LAG(score) OVER (ORDER BY date) AS prev_score
FROM scores;
-- 6. Partition large queries
SELECT * FROM large_table
WHERE date >= '2024-01-01' AND date < '2024-02-01';
-- 7. Use CTEs for readability
WITH filtered AS (SELECT * FROM users WHERE active = true)
SELECT * FROM filtered;
-- 8. Analyze queries for optimization
EXPLAIN ANALYZE SELECT * FROM complex_query;
-- 9. Read directly from files when possible
SELECT * FROM 'data.parquet' WHERE column > 100;
-- 10. Use appropriate compression
COPY data TO 'compressed.parquet' (FORMAT PARQUET, COMPRESSION ZSTD);
Quick Reference
| Command | Description |
|---|---|
| read_csv_auto('file.csv') | Read CSV file |
| read_parquet('file.parquet') | Read Parquet file |
| COPY table TO 'file.csv' | Export to CSV |
| EXPLAIN ANALYZE | Show query plan |
| SET threads TO 8 | Set thread count |
| DESCRIBE table | Show table schema |
| .tables | List tables |
| .mode csv | Set output format |
| VACUUM | Optimize database |
| ANALYZE | Update statistics |
DuckDB excels at analytical queries on local data files, making it perfect for data analysis, ETL pipelines, and embedded analytics applications.
NoSQL Databases
Overview
NoSQL databases store data in non-relational formats (documents, key-value, graph, etc.). They are designed for scalability, flexibility, and high performance.
Types
Document Databases (MongoDB)
// Insert
db.users.insertOne({ name: "John", age: 30, email: "john@example.com" });
// Find
db.users.findOne({ name: "John" });
db.users.find({ age: { $gt: 25 } });
// Update
db.users.updateOne({ _id: ObjectId("...") }, { $set: { age: 31 } });
// Delete
db.users.deleteOne({ name: "John" });
// Aggregation
db.users.aggregate([
{ $match: { age: { $gt: 25 } } },
{ $group: { _id: "$city", count: { $sum: 1 } } },
{ $sort: { count: -1 } }
]);
Key-Value Stores (Redis)
# Strings
SET key value
GET key
INCR counter
# Lists
LPUSH mylist "a" "b" "c"
LPOP mylist
LRANGE mylist 0 -1
# Sets
SADD myset "a" "b" "c"
SMEMBERS myset
# Hashes
HSET user:1 name "John" age 30
HGET user:1 name
HGETALL user:1
# Expiration
EXPIRE key 3600 # 1 hour TTL
Column-Family (Cassandra)
-- Wide, denormalized columns
CREATE TABLE users (
user_id UUID PRIMARY KEY,
name TEXT,
email TEXT,
created_at TIMESTAMP,
metadata MAP<TEXT, TEXT>
);
Graph Databases (Neo4j)
// Create
CREATE (n:Person {name: "John", age: 30})
CREATE (m:Company {name: "Acme"})
CREATE (n)-[:WORKS_AT]->(m)
// Query
MATCH (p:Person)-[:WORKS_AT]->(c:Company)
WHERE p.age > 25
RETURN p.name, c.name
// Find friends
MATCH (p:Person {name: "John"})-[:FRIEND*1..2]-(friend)
RETURN friend.name
CAP Theorem
Every distributed database trades off:
- Consistency: All nodes see same data
- Availability: System always responsive
- Partition Tolerance: Survive network splits
You can have 2 of 3:
- CP: Strong consistency, unavailable during partitions (Spanner)
- AP: Always available, eventual consistency (Dynamo, Cassandra)
- CA: Consistent and available, but cannot tolerate partitions — only achievable when the system is not actually distributed (e.g. a single-node traditional DB)
Use Cases
| Database | Best For |
|---|---|
| MongoDB | Flexible schema, documents |
| Redis | Caching, sessions, real-time |
| Cassandra | Time-series, massive scale |
| Neo4j | Graph queries, relationships |
| Elasticsearch | Full-text search, logs |
Data Modeling
# Denormalization (NoSQL style)
# One document with all info
{
"_id": "user_1",
"name": "John",
"orders": [
{ "id": "order_1", "amount": 100 },
{ "id": "order_2", "amount": 200 }
]
}
# vs SQL (normalization)
# users table + orders table + JOIN
ELI10
NoSQL is like a flexible filing system:
- Document DB: Store complete documents (like PDF files)
- Key-Value: Simple lookup (like phone book)
- Graph: Show relationships (like social network)
- Column: Organize by columns not rows (like spreadsheet)
You trade strict structure for flexibility and speed!
MongoDB
MongoDB is a popular NoSQL database that stores data in flexible, JSON-like documents. It's designed for scalability, high performance, and ease of development, making it ideal for modern applications that require flexible schema design and horizontal scaling.
Table of Contents
- Introduction
- Installation and Setup
- CRUD Operations
- Data Modeling
- Queries and Aggregation
- Indexing
- MongoDB with Node.js
- Best Practices
- Performance Optimization
Introduction
Key Features:
- Document-oriented storage (JSON/BSON)
- Flexible schema design
- High performance
- High availability (Replica Sets)
- Horizontal scalability (Sharding)
- Rich query language
- Aggregation framework
- GridFS for large files
- Change Streams for real-time data
Use Cases:
- Content management systems
- Real-time analytics
- IoT applications
- Mobile applications
- Catalogs and inventory
- User data management
- Caching layer
Installation and Setup
Install MongoDB
macOS:
brew tap mongodb/brew
brew install mongodb-community
brew services start mongodb-community
Ubuntu:
wget -qO - https://www.mongodb.org/static/pgp/server-6.0.asc | sudo apt-key add -
echo "deb [ arch=amd64,arm64 ] https://repo.mongodb.org/apt/ubuntu focal/mongodb-org/6.0 multiverse" | sudo tee /etc/apt/sources.list.d/mongodb-org-6.0.list
sudo apt-get update
sudo apt-get install -y mongodb-org
sudo systemctl start mongod
Docker:
docker run -d -p 27017:27017 --name mongodb mongo:latest
MongoDB Shell
# Connect to MongoDB
mongosh
# Show databases
show dbs
# Use/create database
use mydb
# Show collections
show collections
# Exit
exit
CRUD Operations
Create (Insert)
// Insert one document
db.users.insertOne({
name: "John Doe",
email: "john@example.com",
age: 30,
createdAt: new Date()
})
// Insert multiple documents
db.users.insertMany([
{ name: "Jane Smith", email: "jane@example.com", age: 28 },
{ name: "Bob Johnson", email: "bob@example.com", age: 35 }
])
Read (Query)
// Find all documents
db.users.find()
// Find with filter
db.users.find({ age: { $gte: 30 } })
// Find one document
db.users.findOne({ email: "john@example.com" })
// Projection (select specific fields)
db.users.find({}, { name: 1, email: 1, _id: 0 })
// Limit and sort
db.users.find().limit(10).sort({ age: -1 })
// Count documents
db.users.countDocuments({ age: { $gte: 30 } })
Update
// Update one document
db.users.updateOne(
{ email: "john@example.com" },
{ $set: { age: 31, updatedAt: new Date() } }
)
// Update multiple documents
db.users.updateMany(
{ age: { $lt: 30 } },
{ $set: { status: "young" } }
)
// Replace document
db.users.replaceOne(
{ email: "john@example.com" },
{ name: "John Doe", email: "john@example.com", age: 31 }
)
// Upsert (update or insert)
db.users.updateOne(
{ email: "new@example.com" },
{ $set: { name: "New User", age: 25 } },
{ upsert: true }
)
// Increment field
db.users.updateOne(
{ email: "john@example.com" },
{ $inc: { loginCount: 1 } }
)
// Add to array
db.users.updateOne(
{ email: "john@example.com" },
{ $push: { hobbies: "reading" } }
)
Delete
// Delete one document
db.users.deleteOne({ email: "john@example.com" })
// Delete multiple documents
db.users.deleteMany({ age: { $lt: 18 } })
// Delete all documents
db.users.deleteMany({})
Data Modeling
Embedded Documents
// User with embedded address
db.users.insertOne({
name: "John Doe",
email: "john@example.com",
address: {
street: "123 Main St",
city: "New York",
state: "NY",
zip: "10001"
},
phoneNumbers: [
{ type: "home", number: "555-1234" },
{ type: "work", number: "555-5678" }
]
})
Document References
// Posts collection
db.posts.insertOne({
title: "My First Post",
content: "This is my first blog post",
authorId: ObjectId("user_id_here"),
comments: [
{
userId: ObjectId("commenter_id"),
text: "Great post!",
createdAt: new Date()
}
]
})
// Query with lookup
db.posts.aggregate([
{
$lookup: {
from: "users",
localField: "authorId",
foreignField: "_id",
as: "author"
}
}
])
Schema Design Patterns
// One-to-One (Embedded)
{
_id: ObjectId(),
username: "john_doe",
profile: {
firstName: "John",
lastName: "Doe",
bio: "Software developer"
}
}
// One-to-Many (Embedded - for small arrays)
{
_id: ObjectId(),
title: "Blog Post",
tags: ["mongodb", "database", "nosql"]
}
// One-to-Many (Referenced - for large collections)
{
_id: ObjectId(),
name: "Category",
products: [
ObjectId("product_1"),
ObjectId("product_2")
]
}
// Many-to-Many
// Users collection
{
_id: ObjectId("user_1"),
name: "John",
courseIds: [ObjectId("course_1"), ObjectId("course_2")]
}
// Courses collection
{
_id: ObjectId("course_1"),
title: "MongoDB Course",
studentIds: [ObjectId("user_1"), ObjectId("user_2")]
}
Queries and Aggregation
Query Operators
// Comparison operators
db.users.find({ age: { $eq: 30 } }) // Equal
db.users.find({ age: { $ne: 30 } }) // Not equal
db.users.find({ age: { $gt: 30 } }) // Greater than
db.users.find({ age: { $gte: 30 } }) // Greater than or equal
db.users.find({ age: { $lt: 30 } }) // Less than
db.users.find({ age: { $lte: 30 } }) // Less than or equal
db.users.find({ age: { $in: [25, 30, 35] } }) // In array
db.users.find({ age: { $nin: [25, 30] } }) // Not in array
// Logical operators
db.users.find({
$and: [
{ age: { $gte: 25 } },
{ age: { $lte: 35 } }
]
})
db.users.find({
$or: [
{ age: { $lt: 25 } },
{ age: { $gt: 35 } }
]
})
db.users.find({ age: { $not: { $gte: 30 } } })
// Element operators
db.users.find({ email: { $exists: true } })
db.users.find({ age: { $type: "number" } })
// Array operators
db.users.find({ hobbies: { $all: ["reading", "gaming"] } })
db.users.find({ hobbies: { $size: 3 } })
db.users.find({ "hobbies.0": "reading" })
// Text search
db.posts.createIndex({ title: "text", content: "text" })
db.posts.find({ $text: { $search: "mongodb tutorial" } })
Aggregation Pipeline
// Basic aggregation
db.orders.aggregate([
// Match stage (filter)
{ $match: { status: "completed" } },
// Group stage
{
$group: {
_id: "$customerId",
totalSpent: { $sum: "$amount" },
orderCount: { $sum: 1 },
avgOrderAmount: { $avg: "$amount" }
}
},
// Sort stage
{ $sort: { totalSpent: -1 } },
// Limit stage
{ $limit: 10 }
])
// Complex aggregation with lookup
db.orders.aggregate([
// Join with users collection
{
$lookup: {
from: "users",
localField: "userId",
foreignField: "_id",
as: "user"
}
},
// Unwind array
{ $unwind: "$user" },
// Project (reshape documents)
{
$project: {
orderNumber: 1,
amount: 1,
userName: "$user.name",
userEmail: "$user.email"
}
}
])
// Aggregation operators
db.sales.aggregate([
{
$group: {
_id: "$category",
total: { $sum: "$amount" },
avg: { $avg: "$amount" },
min: { $min: "$amount" },
max: { $max: "$amount" },
count: { $sum: 1 },
items: { $push: "$productName" },
first: { $first: "$date" },
last: { $last: "$date" }
}
}
])
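To build intuition for what the $match -> $group -> $sort stages compute, here is the equivalent of the first pipeline above in plain Python (the sample orders are hypothetical; this is an illustration, not how MongoDB executes it):

```python
from collections import defaultdict

orders = [
    {"customerId": "c1", "status": "completed", "amount": 100},
    {"customerId": "c1", "status": "completed", "amount": 50},
    {"customerId": "c2", "status": "completed", "amount": 200},
    {"customerId": "c2", "status": "pending",   "amount": 999},
]

# $match: keep only completed orders
matched = [o for o in orders if o["status"] == "completed"]

# $group: accumulate $sum and the document count per customerId
groups = defaultdict(lambda: {"totalSpent": 0, "orderCount": 0})
for o in matched:
    g = groups[o["customerId"]]
    g["totalSpent"] += o["amount"]
    g["orderCount"] += 1

# $sort: descending by totalSpent
result = sorted(
    ({"_id": k, **v} for k, v in groups.items()),
    key=lambda g: g["totalSpent"],
    reverse=True,
)
print(result)
# [{'_id': 'c2', 'totalSpent': 200, 'orderCount': 1},
#  {'_id': 'c1', 'totalSpent': 150, 'orderCount': 2}]
```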
Indexing
Creating Indexes
// Single field index
db.users.createIndex({ email: 1 })
// Compound index
db.users.createIndex({ lastName: 1, firstName: 1 })
// Unique index
db.users.createIndex({ email: 1 }, { unique: true })
// Text index
db.posts.createIndex({ title: "text", content: "text" })
// 2dsphere index (geospatial)
db.locations.createIndex({ coordinates: "2dsphere" })
// TTL index (auto-delete after time)
db.sessions.createIndex(
{ createdAt: 1 },
{ expireAfterSeconds: 3600 }
)
// Partial index
db.orders.createIndex(
{ status: 1 },
{ partialFilterExpression: { status: "active" } }
)
// Sparse index
db.users.createIndex(
{ phoneNumber: 1 },
{ sparse: true }
)
Index Management
// List indexes
db.users.getIndexes()
// Drop index
db.users.dropIndex("email_1")
// Drop all indexes
db.users.dropIndexes()
// Explain query (check index usage)
db.users.find({ email: "john@example.com" }).explain("executionStats")
MongoDB with Node.js
Installation
npm install mongodb
# or
npm install mongoose
Native MongoDB Driver
const { MongoClient } = require('mongodb');
const url = 'mongodb://localhost:27017';
const client = new MongoClient(url);
async function main() {
await client.connect();
console.log('Connected to MongoDB');
const db = client.db('mydb');
const users = db.collection('users');
// Insert
const result = await users.insertOne({
name: 'John Doe',
email: 'john@example.com',
age: 30
});
console.log('Inserted:', result.insertedId);
// Find
const user = await users.findOne({ email: 'john@example.com' });
console.log('Found:', user);
// Update
await users.updateOne(
{ email: 'john@example.com' },
{ $set: { age: 31 } }
);
// Delete
await users.deleteOne({ email: 'john@example.com' });
await client.close();
}
main().catch(console.error);
Mongoose ODM
const mongoose = require('mongoose');
// Connect (useNewUrlParser and useUnifiedTopology are deprecated no-ops since Mongoose 6)
mongoose.connect('mongodb://localhost:27017/mydb');
// Define schema
const userSchema = new mongoose.Schema({
name: { type: String, required: true },
email: { type: String, required: true, unique: true },
age: { type: Number, min: 0, max: 120 },
createdAt: { type: Date, default: Date.now },
address: {
street: String,
city: String,
state: String,
zip: String
},
hobbies: [String],
status: {
type: String,
enum: ['active', 'inactive', 'banned'],
default: 'active'
}
});
// Instance methods
userSchema.methods.getFullInfo = function() {
return `${this.name} (${this.email})`;
};
// Static methods
userSchema.statics.findByEmail = function(email) {
return this.findOne({ email });
};
// Virtuals
userSchema.virtual('isAdult').get(function() {
return this.age >= 18;
});
// Middleware
userSchema.pre('save', function(next) {
console.log('About to save user:', this.name);
next();
});
// Create model
const User = mongoose.model('User', userSchema);
// CRUD operations
async function examples() {
// Create
const user = new User({
name: 'John Doe',
email: 'john@example.com',
age: 30,
hobbies: ['reading', 'coding']
});
await user.save();
// Find
const users = await User.find({ age: { $gte: 25 } });
const john = await User.findByEmail('john@example.com');
// Update
await User.updateOne({ email: 'john@example.com' }, { age: 31 });
// or
john.age = 31;
await john.save();
// Delete
await User.deleteOne({ email: 'john@example.com' });
// Populate (references)
const postSchema = new mongoose.Schema({
title: String,
author: { type: mongoose.Schema.Types.ObjectId, ref: 'User' }
});
const Post = mongoose.model('Post', postSchema);
const posts = await Post.find().populate('author');
}
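Query chains like the ones above compose naturally into a pagination helper. A minimal sketch — the `paginate` helper and its option names are illustrative, not a Mongoose API:

```javascript
// Paginate any Mongoose model; page numbers are 1-based (helper name is illustrative)
async function paginate(Model, { page = 1, pageSize = 10, filter = {}, sort = { createdAt: -1 } } = {}) {
  const [items, total] = await Promise.all([
    Model.find(filter).sort(sort).skip((page - 1) * pageSize).limit(pageSize),
    Model.countDocuments(filter)
  ]);
  return { items, total, page, totalPages: Math.ceil(total / pageSize) };
}

// Usage: const { items, totalPages } = await paginate(User, { page: 2, pageSize: 20 });
```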
Express + Mongoose API
const express = require('express');
const mongoose = require('mongoose');
const app = express();
app.use(express.json());
// Connect to MongoDB
mongoose.connect('mongodb://localhost:27017/mydb');
// User model
const User = mongoose.model('User', new mongoose.Schema({
name: { type: String, required: true },
email: { type: String, required: true, unique: true },
age: Number
}));
// Routes
app.get('/users', async (req, res) => {
try {
const users = await User.find();
res.json(users);
} catch (error) {
res.status(500).json({ error: error.message });
}
});
app.get('/users/:id', async (req, res) => {
try {
const user = await User.findById(req.params.id);
if (!user) return res.status(404).json({ error: 'User not found' });
res.json(user);
} catch (error) {
res.status(500).json({ error: error.message });
}
});
app.post('/users', async (req, res) => {
try {
const user = new User(req.body);
await user.save();
res.status(201).json(user);
} catch (error) {
res.status(400).json({ error: error.message });
}
});
app.put('/users/:id', async (req, res) => {
try {
const user = await User.findByIdAndUpdate(
req.params.id,
req.body,
{ new: true, runValidators: true }
);
if (!user) return res.status(404).json({ error: 'User not found' });
res.json(user);
} catch (error) {
res.status(400).json({ error: error.message });
}
});
app.delete('/users/:id', async (req, res) => {
try {
const user = await User.findByIdAndDelete(req.params.id);
if (!user) return res.status(404).json({ error: 'User not found' });
res.json({ message: 'User deleted' });
} catch (error) {
res.status(500).json({ error: error.message });
}
});
app.listen(3000, () => console.log('Server running on port 3000'));
Best Practices
1. Schema Design
// Embed related data when:
// - Data is frequently accessed together
// - Data doesn't change often
// - Array size is bounded
// Reference when:
// - Data is frequently accessed separately
// - Data changes frequently
// - Array size is unbounded
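The two patterns can be sketched with Mongoose schemas (the field names here are illustrative):

```javascript
const mongoose = require('mongoose');

// Embedded: comments live inside the post document (bounded, read together)
const blogPostSchema = new mongoose.Schema({
  title: String,
  comments: [{ author: String, text: String, createdAt: Date }]
});

// Referenced: orders point at a user; user data changes independently
const orderSchema = new mongoose.Schema({
  total: Number,
  user: { type: mongoose.Schema.Types.ObjectId, ref: 'User' }
});
// Later: Order.find().populate('user') pulls the referenced documents in
```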
2. Use Appropriate Indexes
// Index fields used in queries
db.users.createIndex({ email: 1 })
// Compound indexes for multi-field queries
db.users.createIndex({ status: 1, createdAt: -1 })
// Monitor index usage
db.users.aggregate([{ $indexStats: {} }])
3. Validate Data
// Mongoose validation
const userSchema = new mongoose.Schema({
email: {
type: String,
required: true,
validate: {
validator: function(v) {
return /^[\w-\.]+@([\w-]+\.)+[\w-]{2,4}$/.test(v);
},
message: props => `${props.value} is not a valid email!`
}
},
age: {
type: Number,
min: [0, 'Age must be positive'],
max: [120, 'Age seems unrealistic']
}
});
4. Handle Errors
try {
await User.create({ email: 'invalid' });
} catch (error) {
if (error.name === 'ValidationError') {
// Handle validation error
} else if (error.code === 11000) {
// Handle duplicate key error
}
}
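That branching logic can be factored into a small helper that maps database errors to HTTP responses — the helper name and status choices are mine, not a Mongoose API:

```javascript
// Map Mongoose validation errors and MongoDB duplicate-key errors to HTTP statuses
function classifyDbError(error) {
  if (error.name === 'ValidationError') {
    const messages = Object.values(error.errors || {}).map((e) => e.message);
    return { status: 400, message: messages.join('; ') || 'Validation failed' };
  }
  if (error.code === 11000) {
    return { status: 409, message: 'Duplicate key' }; // unique index violated
  }
  return { status: 500, message: 'Internal server error' };
}

// In a route handler:
// catch (error) { const { status, message } = classifyDbError(error); res.status(status).json({ error: message }); }
```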
5. Use Transactions (for multi-document operations)
const session = await mongoose.startSession();
session.startTransaction();
try {
await User.create([{ name: 'John' }], { session });
await Post.create([{ title: 'First Post' }], { session });
await session.commitTransaction();
} catch (error) {
await session.abortTransaction();
throw error;
} finally {
session.endSession();
}
Performance Optimization
1. Query Optimization
// Use projection
db.users.find({}, { name: 1, email: 1 })
// Use covered queries (query uses only indexed fields)
db.users.createIndex({ email: 1, name: 1 })
db.users.find({ email: 'john@example.com' }, { email: 1, name: 1, _id: 0 })
// Limit results
db.users.find().limit(10)
// Use lean() in Mongoose (skip hydration)
const users = await User.find().lean()
2. Connection Pooling
const mongoose = require('mongoose');
mongoose.connect('mongodb://localhost:27017/mydb', {
maxPoolSize: 10,
minPoolSize: 5
});
3. Batch Operations
// Bulk insert
db.users.insertMany([
{ name: 'User 1' },
{ name: 'User 2' },
{ name: 'User 3' }
], { ordered: false })
// Bulk write
db.users.bulkWrite([
{ insertOne: { document: { name: 'John' } } },
{ updateOne: { filter: { name: 'Jane' }, update: { $set: { age: 30 } } } },
{ deleteOne: { filter: { name: 'Bob' } } }
])
4. Caching
const redis = require('redis');
const client = redis.createClient();
await client.connect();
async function getUser(id) {
// Check cache first
const cached = await client.get(`user:${id}`);
if (cached) return JSON.parse(cached);
// Query database
const user = await User.findById(id);
// Store in cache for one hour
await client.setEx(`user:${id}`, 3600, JSON.stringify(user));
return user;
}
Resources
Official Documentation:
Tools:
- MongoDB Compass - GUI
- Studio 3T - IDE for MongoDB
- mongosh - MongoDB Shell
Learning:
Redis
Redis (Remote Dictionary Server) is an open-source, in-memory data structure store used as a database, cache, message broker, and streaming engine. Known for its high performance and versatility, Redis supports various data structures and is widely used for real-time applications.
Table of Contents
- Introduction
- Installation and Setup
- Data Structures
- Common Operations
- Caching Strategies
- Pub/Sub Messaging
- Redis with Node.js
- Best Practices
- Performance and Persistence
Introduction
Key Features:
- In-memory data storage
- Sub-millisecond latency
- Multiple data structures (strings, hashes, lists, sets, sorted sets)
- Pub/Sub messaging
- Transactions
- Lua scripting
- Persistence options (RDB, AOF)
- Replication and high availability
- Clustering for horizontal scaling
Use Cases:
- Caching
- Session storage
- Real-time analytics
- Leaderboards and counting
- Rate limiting
- Message queues
- Real-time chat applications
- Geospatial data
Installation and Setup
Install Redis
macOS:
brew install redis
brew services start redis
Ubuntu:
sudo apt update
sudo apt install redis-server
sudo systemctl start redis-server
Docker:
docker run -d -p 6379:6379 --name redis redis:latest
Redis CLI
# Connect to Redis
redis-cli
# Test connection
127.0.0.1:6379> PING
PONG
# Select database (0-15)
127.0.0.1:6379> SELECT 1
# Get all keys
127.0.0.1:6379> KEYS *
# Clear database
127.0.0.1:6379> FLUSHDB
# Clear all databases
127.0.0.1:6379> FLUSHALL
Data Structures
Strings
# Set and get
SET name "John Doe"
GET name
# Set with expiration (seconds)
SETEX session:123 3600 "user_data"
# Set if not exists
SETNX key "value"
# Multiple set/get
MSET key1 "value1" key2 "value2"
MGET key1 key2
# Increment/decrement
SET counter 10
INCR counter # 11
INCRBY counter 5 # 16
DECR counter # 15
DECRBY counter 3 # 12
# Append
APPEND key "more_data"
# Get length
STRLEN key
Hashes (Objects)
# Set hash field
HSET user:1 name "John" age 30 email "john@example.com"
# Get hash field
HGET user:1 name
# Get all fields
HGETALL user:1
# Get multiple fields
HMGET user:1 name email
# Check if field exists
HEXISTS user:1 name
# Delete field
HDEL user:1 age
# Get all keys/values
HKEYS user:1
HVALS user:1
# Increment hash field
HINCRBY user:1 loginCount 1
Lists
# Push to list
LPUSH mylist "first" # Push to left
RPUSH mylist "last" # Push to right
# Pop from list
LPOP mylist # Pop from left
RPOP mylist # Pop from right
# Get range
LRANGE mylist 0 -1 # Get all
LRANGE mylist 0 9 # Get first 10
# Get by index
LINDEX mylist 0
# List length
LLEN mylist
# Trim list
LTRIM mylist 0 99 # Keep first 100 items
# Blocking pop (for queues)
BLPOP mylist 0 # Block until item available
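LPUSH plus a blocking pop is the classic Redis work-queue pattern. A sketch of both sides, assuming a connected node-redis v4 client is passed in (the function names are illustrative):

```javascript
// Producers push jobs onto the left; the worker blocks on the right (FIFO)
async function enqueue(client, queue, job) {
  await client.lPush(queue, JSON.stringify(job));
}

async function runWorker(client, queue, handler) {
  while (true) {
    // brPop blocks until a job is available; timeout 0 = wait forever
    const res = await client.brPop(queue, 0);
    if (res) await handler(JSON.parse(res.element)); // res = { key, element }
  }
}
```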
Sets
# Add members
SADD myset "member1" "member2" "member3"
# Get all members
SMEMBERS myset
# Check membership
SISMEMBER myset "member1"
# Remove member
SREM myset "member1"
# Set operations
SUNION set1 set2 # Union
SINTER set1 set2 # Intersection
SDIFF set1 set2 # Difference
# Random member
SRANDMEMBER myset
SPOP myset # Pop random member
# Set size
SCARD myset
Sorted Sets (Leaderboards)
# Add members with scores
ZADD leaderboard 100 "player1" 200 "player2" 150 "player3"
# Get range by rank
ZRANGE leaderboard 0 9 # Top 10 (ascending)
ZREVRANGE leaderboard 0 9 # Top 10 (descending)
# Get range with scores
ZRANGE leaderboard 0 9 WITHSCORES
# Get rank
ZRANK leaderboard "player1" # Ascending rank
ZREVRANK leaderboard "player1" # Descending rank
# Get score
ZSCORE leaderboard "player1"
# Increment score
ZINCRBY leaderboard 50 "player1"
# Range by score
ZRANGEBYSCORE leaderboard 100 200
# Count in range
ZCOUNT leaderboard 100 200
# Remove member
ZREM leaderboard "player1"
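The same commands drive a leaderboard from Node.js. A sketch assuming a connected node-redis v4 client (the helper names are mine):

```javascript
// Add points to a player (creates the member if it doesn't exist yet)
async function addScore(client, board, player, points) {
  await client.zIncrBy(board, points, player);
}

// Top N players, highest score first
async function topN(client, board, n = 10) {
  return client.zRangeWithScores(board, 0, n - 1, { REV: true });
}

// 1-based rank of a player, or null if absent
async function rankOf(client, board, player) {
  const r = await client.zRevRank(board, player); // 0-based
  return r === null ? null : r + 1;
}
```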
Common Operations
Key Management
# Set expiration
EXPIRE key 60 # Expire in 60 seconds
EXPIREAT key 1609459200 # Expire at timestamp
TTL key # Get time to live
PERSIST key # Remove expiration
# Delete keys
DEL key1 key2 key3
# Check if key exists
EXISTS key
# Get key type
TYPE key
# Rename key
RENAME oldkey newkey
RENAMENX oldkey newkey # Rename if new key doesn't exist
# Get all keys matching pattern
KEYS user:*
SCAN 0 MATCH user:* COUNT 10 # Better for production
Transactions
MULTI
SET key1 "value1"
SET key2 "value2"
INCR counter
EXEC
# With watch (optimistic locking)
WATCH key
MULTI
SET key "new_value"
EXEC
Caching Strategies
Cache-Aside (Lazy Loading)
async function getUser(id) {
const cacheKey = `user:${id}`;
// Try cache first
let user = await redis.get(cacheKey);
if (user) {
return JSON.parse(user);
}
// Cache miss - load from database
user = await db.users.findById(id);
// Store in cache
await redis.setEx(cacheKey, 3600, JSON.stringify(user));
return user;
}
Write-Through Cache
async function updateUser(id, data) {
const cacheKey = `user:${id}`;
// Update database
const user = await db.users.updateById(id, data);
// Update cache
await redis.setEx(cacheKey, 3600, JSON.stringify(user));
return user;
}
Write-Behind (Write-Back) Cache
async function updateUser(id, data) {
const cacheKey = `user:${id}`;
// Update cache immediately
await redis.setEx(cacheKey, 3600, JSON.stringify(data));
// Queue database write
await redis.lPush('user:updates', JSON.stringify({ id, data }));
return data;
}
// Background worker
async function processUpdates() {
while (true) {
const update = await redis.brPop('user:updates', 0);
if (update) {
const { id, data } = JSON.parse(update.element);
await db.users.updateById(id, data);
}
}
}
Pub/Sub Messaging
Basic Pub/Sub
const redis = require('redis');
// Publisher
const publisher = redis.createClient();
publisher.publish('news', 'Breaking news!');
// Subscriber
const subscriber = redis.createClient();
subscriber.subscribe('news');
subscriber.on('message', (channel, message) => {
console.log(`Received from ${channel}: ${message}`);
});
// Pattern subscribe
subscriber.psubscribe('user:*');
subscriber.on('pmessage', (pattern, channel, message) => {
console.log(`Pattern ${pattern}, Channel ${channel}: ${message}`);
});
Real-Time Chat Example
const express = require('express');
const http = require('http');
const socketIo = require('socket.io');
const redis = require('redis');
const app = express();
const server = http.createServer(app);
const io = socketIo(server);
const publisher = redis.createClient();
const subscriber = redis.createClient();
subscriber.subscribe('chat:messages');
// Handle Redis messages
subscriber.on('message', (channel, message) => {
if (channel === 'chat:messages') {
io.emit('message', JSON.parse(message));
}
});
// Handle WebSocket connections
io.on('connection', (socket) => {
console.log('User connected');
socket.on('message', (msg) => {
const message = {
user: socket.id,
text: msg,
timestamp: Date.now()
};
// Publish to Redis
publisher.publish('chat:messages', JSON.stringify(message));
// Store history on the publisher connection (a subscribed client can't issue other commands)
publisher.lpush('chat:history', JSON.stringify(message));
publisher.ltrim('chat:history', 0, 99); // Keep last 100 messages
});
socket.on('disconnect', () => {
console.log('User disconnected');
});
});
server.listen(3000);
Redis with Node.js
Using node-redis
npm install redis
Basic Usage:
const redis = require('redis');
const client = redis.createClient({
url: 'redis://localhost:6379'
});
client.on('error', (err) => console.error('Redis error:', err));
client.on('connect', () => console.log('Connected to Redis'));
await client.connect();
// String operations
await client.set('key', 'value');
const value = await client.get('key');
// Hash operations
await client.hSet('user:1', 'name', 'John');
await client.hSet('user:1', 'age', '30');
const user = await client.hGetAll('user:1');
// List operations
await client.lPush('mylist', 'item1');
await client.rPush('mylist', 'item2');
const items = await client.lRange('mylist', 0, -1);
// Set operations
await client.sAdd('myset', 'member1');
await client.sAdd('myset', 'member2');
const members = await client.sMembers('myset');
// Sorted set operations
await client.zAdd('leaderboard', { score: 100, value: 'player1' });
const top = await client.zRange('leaderboard', 0, 9, { REV: true });
await client.disconnect();
Caching Middleware (Express)
const redis = require('redis');
const client = redis.createClient();
await client.connect();
function cache(duration) {
return async (req, res, next) => {
const key = `cache:${req.originalUrl}`;
try {
const cachedResponse = await client.get(key);
if (cachedResponse) {
return res.json(JSON.parse(cachedResponse));
}
// Modify res.json to cache response
const originalJson = res.json.bind(res);
res.json = (body) => {
client.setEx(key, duration, JSON.stringify(body)).catch(() => {});
return originalJson(body);
};
next();
} catch (error) {
next();
}
};
}
// Usage
app.get('/api/users', cache(300), async (req, res) => {
const users = await db.users.findAll();
res.json(users);
});
Session Storage
const session = require('express-session');
const RedisStore = require('connect-redis').default;
const redis = require('redis');
const redisClient = redis.createClient();
await redisClient.connect();
app.use(
session({
store: new RedisStore({ client: redisClient }),
secret: 'your-secret',
resave: false,
saveUninitialized: false,
cookie: {
secure: false, // Set true for HTTPS
httpOnly: true,
maxAge: 1000 * 60 * 60 * 24 // 1 day
}
})
);
app.get('/', (req, res) => {
if (req.session.views) {
req.session.views++;
} else {
req.session.views = 1;
}
res.send(`Views: ${req.session.views}`);
});
Rate Limiting
async function rateLimiter(userId, maxRequests = 10, windowSeconds = 60) {
const key = `rate_limit:${userId}`;
const current = await client.incr(key);
if (current === 1) {
await client.expire(key, windowSeconds);
}
if (current > maxRequests) {
const ttl = await client.ttl(key);
throw new Error(`Rate limit exceeded. Try again in ${ttl} seconds`);
}
return {
remaining: maxRequests - current,
reset: windowSeconds
};
}
// Middleware
async function rateLimitMiddleware(req, res, next) {
const userId = req.user?.id || req.ip;
try {
const result = await rateLimiter(userId);
res.set('X-RateLimit-Remaining', result.remaining);
res.set('X-RateLimit-Reset', result.reset);
next();
} catch (error) {
res.status(429).json({ error: error.message });
}
}
app.use(rateLimitMiddleware);
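The fixed-window counter above allows bursts at window boundaries; a sorted set gives a sliding window instead. A sketch assuming a connected node-redis v4 client (the key prefix and helper name are mine):

```javascript
// Sliding-window rate limit: one sorted-set entry per request, scored by timestamp
async function slidingWindowLimit(client, userId, maxRequests = 10, windowMs = 60000) {
  const key = `rate_sw:${userId}`;
  const now = Date.now();
  const [, , count] = await client.multi()
    .zRemRangeByScore(key, 0, now - windowMs)                    // drop requests outside the window
    .zAdd(key, { score: now, value: `${now}:${Math.random()}` }) // record this request
    .zCard(key)                                                  // count requests in the window
    .expire(key, Math.ceil(windowMs / 1000))                     // let idle keys die
    .exec();
  return { allowed: count <= maxRequests, remaining: Math.max(0, maxRequests - count) };
}
```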
Distributed Locking
async function acquireLock(lockKey, timeout = 10000) {
const lockValue = Math.random().toString(36);
const result = await client.set(lockKey, lockValue, {
NX: true,
PX: timeout
});
if (result === 'OK') {
return lockValue;
}
return null;
}
async function releaseLock(lockKey, lockValue) {
const script = `
if redis.call("get", KEYS[1]) == ARGV[1] then
return redis.call("del", KEYS[1])
else
return 0
end
`;
return await client.eval(script, {
keys: [lockKey],
arguments: [lockValue]
});
}
// Usage
async function criticalSection() {
const lock = await acquireLock('resource:lock');
if (!lock) {
throw new Error('Could not acquire lock');
}
try {
// Perform critical operation
await performOperation();
} finally {
await releaseLock('resource:lock', lock);
}
}
Best Practices
1. Key Naming Conventions
# Use descriptive, hierarchical names
user:1:profile
user:1:sessions
order:12345:items
# Pick one separator and use it everywhere (colon is the common convention)
user:1:profile # Colon-separated
user_1_profile # Underscore-separated
# Include type in key name
string:user:1:name
hash:user:1
list:user:1:notifications
2. Set Expiration Times
// Always set a TTL for cache keys
await client.setEx('cache:key', 3600, 'value');
// Use appropriate expiration times
const MINUTE = 60;
const HOUR = 60 * MINUTE;
const DAY = 24 * HOUR;
await client.setEx('session:123', 30 * MINUTE, data);
await client.setEx('cache:user:1', 1 * HOUR, data);
await client.setEx('temp:verification', 10 * MINUTE, code);
3. Use Pipelines for Multiple Commands
const pipeline = client.multi();
pipeline.set('key1', 'value1');
pipeline.set('key2', 'value2');
pipeline.incr('counter');
pipeline.hSet('user:1', 'name', 'John');
const results = await pipeline.exec();
4. Handle Connection Errors
const client = redis.createClient({
url: 'redis://localhost:6379',
socket: {
reconnectStrategy: (retries) => {
if (retries > 10) {
return new Error('Max retries reached');
}
return retries * 100;
}
}
});
client.on('error', (err) => {
console.error('Redis error:', err);
});
client.on('reconnecting', () => {
console.log('Reconnecting to Redis...');
});
client.on('ready', () => {
console.log('Redis is ready');
});
5. Memory Management
// Set maxmemory and eviction policy in redis.conf
// maxmemory 256mb
// maxmemory-policy allkeys-lru
// Monitor memory usage
const info = await client.info('memory');
console.log(info);
// Use SCAN instead of KEYS
let cursor = 0;
do {
const result = await client.scan(cursor, {
MATCH: 'user:*',
COUNT: 100
});
cursor = result.cursor;
const keys = result.keys;
// Process keys
} while (cursor !== 0);
Performance and Persistence
Persistence Options
RDB (Redis Database Backup):
# redis.conf
save 900 1 # Save if 1 key changed in 15 minutes
save 300 10 # Save if 10 keys changed in 5 minutes
save 60 10000 # Save if 10000 keys changed in 1 minute
dbfilename dump.rdb
dir /var/lib/redis
AOF (Append Only File):
# redis.conf
appendonly yes
appendfilename "appendonly.aof"
# Sync strategy
appendfsync always # Slowest, safest
appendfsync everysec # Good balance (recommended)
appendfsync no # Fastest, least safe
Replication
# On replica
redis-cli
> REPLICAOF master-host 6379
# Check replication status
> INFO replication
Monitoring
// Monitor commands in real time (from redis-cli; MONITOR is expensive in production)
// $ redis-cli MONITOR
// Get stats
const info = await client.info();
console.log(info);
// Slow log (raw command via node-redis v4 sendCommand)
const slowlog = await client.sendCommand(['SLOWLOG', 'GET', '10']);
console.log(slowlog);
Resources
Official Documentation:
Tools:
- RedisInsight - GUI
- redis-cli - Command line
- redis-benchmark - Performance testing
Learning:
Apache Kafka
Apache Kafka is a distributed event streaming platform capable of handling trillions of events a day. It's used for building real-time data pipelines and streaming applications, providing high-throughput, fault-tolerant, and scalable messaging.
Table of Contents
- Introduction
- Core Concepts
- Installation and Setup
- Producers
- Consumers
- Topics and Partitions
- Kafka with Node.js
- Best Practices
- Production Considerations
Introduction
Key Features:
- High-throughput message streaming
- Fault-tolerant and durable
- Horizontal scalability
- Low latency (sub-millisecond)
- Replay capability
- Stream processing with Kafka Streams
- Connect framework for integrations
Use Cases:
- Event-driven architectures
- Log aggregation
- Real-time analytics
- Change Data Capture (CDC)
- Microservices communication
- Stream processing
- Message queuing
- Activity tracking
Core Concepts
Topics
Logical channels for messages, similar to database tables.
Partitions
Topics are split into partitions for parallel processing.
Producers
Applications that publish messages to topics.
Consumers
Applications that subscribe to topics and process messages.
Consumer Groups
Multiple consumers working together to process messages from a topic.
Brokers
Kafka servers that store and serve data.
Zookeeper/KRaft
Coordination service for managing Kafka cluster (KRaft is the newer alternative).
Installation and Setup
Docker Compose Setup
docker-compose.yml:
version: '3'
services:
zookeeper:
image: confluentinc/cp-zookeeper:latest
environment:
ZOOKEEPER_CLIENT_PORT: 2181
ZOOKEEPER_TICK_TIME: 2000
kafka:
image: confluentinc/cp-kafka:latest
depends_on:
- zookeeper
ports:
- "9092:9092"
environment:
KAFKA_BROKER_ID: 1
KAFKA_ZOOKEEPER_CONNECT: zookeeper:2181
KAFKA_ADVERTISED_LISTENERS: PLAINTEXT://localhost:9092
KAFKA_OFFSETS_TOPIC_REPLICATION_FACTOR: 1
docker-compose up -d
CLI Commands
# Create topic
kafka-topics --create \
--bootstrap-server localhost:9092 \
--topic my-topic \
--partitions 3 \
--replication-factor 1
# List topics
kafka-topics --list --bootstrap-server localhost:9092
# Describe topic
kafka-topics --describe \
--bootstrap-server localhost:9092 \
--topic my-topic
# Delete topic
kafka-topics --delete \
--bootstrap-server localhost:9092 \
--topic my-topic
# Produce messages
kafka-console-producer \
--bootstrap-server localhost:9092 \
--topic my-topic
# Consume messages
kafka-console-consumer \
--bootstrap-server localhost:9092 \
--topic my-topic \
--from-beginning
Producers
Basic Producer Concept
// Producer sends messages to topics
Message → Producer → Kafka Broker → Topic Partition
Producer Configuration
{
'bootstrap.servers': 'localhost:9092',
'client.id': 'my-producer',
'acks': 'all', // Wait for all replicas
'compression.type': 'gzip', // Compress messages
'max.in.flight.requests.per.connection': 5,
'retries': 3, // Retry failed sends
'batch.size': 16384, // Batch size in bytes
'linger.ms': 10 // Wait time before sending batch
}
Consumers
Basic Consumer Concept
// Consumers read messages from topics
Kafka Broker → Topic Partition → Consumer Group → Consumer
Consumer Groups
- Multiple consumers in a group share the workload
- Each partition is consumed by only one consumer in a group
- Enables parallel processing and fault tolerance
Consumer Configuration
{
'bootstrap.servers': 'localhost:9092',
'group.id': 'my-consumer-group',
'auto.offset.reset': 'earliest', // Start from beginning if no offset
'enable.auto.commit': false, // Manual commit for reliability
'max.poll.records': 500 // Max records per poll
}
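With `enable.auto.commit` off, the consumer acknowledges each message only after it has been processed. A sketch of how that looks with KafkaJS — the wrapper function is mine; `commitOffsets` is the real KafkaJS call:

```javascript
// Wrap a KafkaJS consumer so offsets are committed only after the handler succeeds
function runWithManualCommits(consumer, handler) {
  return consumer.run({
    autoCommit: false,
    eachMessage: async ({ topic, partition, message }) => {
      await handler(message);
      // Commit the *next* offset to read, hence the +1
      await consumer.commitOffsets([
        { topic, partition, offset: (Number(message.offset) + 1).toString() }
      ]);
    }
  });
}
```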
Topics and Partitions
Topic Design
// Good topic naming
user.events
order.created
payment.processed
notification.email.sent
// Partition strategy
// - More partitions = more parallelism
// - But more partitions = more overhead
// Start with: partitions = target throughput (MB/s) / measured per-partition throughput (MB/s)
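Worked through with made-up numbers, that sizing rule looks like:

```javascript
// Back-of-envelope partition count (all numbers are illustrative)
const targetThroughputMBps = 50;   // what the topic must sustain
const perPartitionMBps = 10;       // measured throughput of one partition
const partitions = Math.ceil(targetThroughputMBps / perPartitionMBps);
console.log(partitions); // 5
```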
Message Keys
// Messages with same key go to same partition
// Ensures ordering for related events
{
key: 'user:123', // All events for user 123 in same partition
value: { ... }
}
Kafka with Node.js
Installation
npm install kafkajs
Producer Example
const { Kafka } = require('kafkajs');
const kafka = new Kafka({
clientId: 'my-app',
brokers: ['localhost:9092']
});
const producer = kafka.producer();
async function sendMessage() {
await producer.connect();
// Send single message
await producer.send({
topic: 'user-events',
messages: [
{
key: 'user:123',
value: JSON.stringify({
userId: 123,
action: 'login',
timestamp: Date.now()
})
}
]
});
// Send multiple messages
await producer.sendBatch({
topicMessages: [
{
topic: 'user-events',
messages: [
{ key: 'user:123', value: JSON.stringify({ action: 'login' }) },
{ key: 'user:124', value: JSON.stringify({ action: 'logout' }) }
]
}
]
});
await producer.disconnect();
}
sendMessage().catch(console.error);
Consumer Example
const { Kafka } = require('kafkajs');
const kafka = new Kafka({
clientId: 'my-app',
brokers: ['localhost:9092']
});
const consumer = kafka.consumer({
groupId: 'my-consumer-group'
});
async function consume() {
await consumer.connect();
await consumer.subscribe({
topic: 'user-events',
fromBeginning: true
});
await consumer.run({
eachMessage: async ({ topic, partition, message }) => {
console.log({
topic,
partition,
offset: message.offset,
key: message.key?.toString(),
value: message.value.toString()
});
// Process message
const event = JSON.parse(message.value.toString());
await processEvent(event);
}
});
}
async function processEvent(event) {
console.log('Processing:', event);
// Your business logic here
}
consume().catch(console.error);
Batch Processing
await consumer.run({
eachBatch: async ({
batch,
resolveOffset,
heartbeat,
isRunning,
isStale
}) => {
const messages = batch.messages;
for (let message of messages) {
if (!isRunning() || isStale()) break;
await processMessage(message);
// Commit offset for this message
resolveOffset(message.offset);
// Send heartbeat to keep consumer alive
await heartbeat();
}
}
});
Error Handling
const consumer = kafka.consumer({
groupId: 'my-group',
retry: {
retries: 8,
initialRetryTime: 100,
multiplier: 2
}
});
consumer.on('consumer.crash', async (event) => {
console.error('Consumer crashed:', event);
// Implement restart logic
});
await consumer.run({
eachMessage: async ({ topic, partition, message }) => {
try {
await processMessage(message);
} catch (error) {
console.error('Processing error:', error);
// Dead letter queue
await producer.send({
topic: 'dead-letter-queue',
messages: [{
key: message.key,
value: message.value,
headers: {
originalTopic: topic,
error: error.message
}
}]
});
}
}
});
Express Integration
const express = require('express');
const { Kafka } = require('kafkajs');
const app = express();
app.use(express.json());
const kafka = new Kafka({
clientId: 'api-server',
brokers: ['localhost:9092']
});
const producer = kafka.producer();
// Connect producer on startup (surface connection failures)
producer.connect().catch(console.error);
// API endpoint to publish events
app.post('/api/events', async (req, res) => {
try {
const { userId, action, data } = req.body;
await producer.send({
topic: 'user-events',
messages: [{
key: `user:${userId}`,
value: JSON.stringify({
userId,
action,
data,
timestamp: Date.now()
})
}]
});
res.json({ success: true, message: 'Event published' });
} catch (error) {
res.status(500).json({ error: error.message });
}
});
// Graceful shutdown
process.on('SIGTERM', async () => {
await producer.disconnect();
process.exit(0);
});
app.listen(3000);
Microservices Communication
Order Service (Producer):
// order-service/producer.js
const { Kafka } = require('kafkajs');
const kafka = new Kafka({
clientId: 'order-service',
brokers: ['localhost:9092']
});
const producer = kafka.producer();
async function createOrder(orderData) {
await producer.connect();
// Publish order created event
await producer.send({
topic: 'order.created',
messages: [{
key: `order:${orderData.id}`,
value: JSON.stringify(orderData)
}]
});
console.log('Order created event published');
}
Inventory Service (Consumer):
// inventory-service/consumer.js
const { Kafka } = require('kafkajs');
const kafka = new Kafka({
clientId: 'inventory-service',
brokers: ['localhost:9092']
});
const consumer = kafka.consumer({
groupId: 'inventory-service-group'
});
const producer = kafka.producer(); // needed to publish the follow-up event below
async function start() {
await consumer.connect();
await producer.connect();
await consumer.subscribe({ topic: 'order.created' });
await consumer.run({
eachMessage: async ({ message }) => {
const order = JSON.parse(message.value.toString());
console.log('Processing order:', order.id);
// Update inventory
await updateInventory(order.items);
// Publish inventory updated event
await producer.send({
topic: 'inventory.updated',
messages: [{
key: `order:${order.id}`,
value: JSON.stringify({
orderId: order.id,
status: 'inventory_reserved'
})
}]
});
}
});
}
start().catch(console.error);
Best Practices
1. Message Design
// Include metadata
{
id: 'uuid',
type: 'order.created',
timestamp: 1234567890,
version: '1.0',
data: {
orderId: 123,
userId: 456,
items: [...]
}
}
// Use schema registry for validation
// Use Avro or Protobuf for efficient serialization
2. Error Handling
// Implement retry logic with exponential backoff
const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms));
async function processWithRetry(message, maxRetries = 3) {
for (let attempt = 1; attempt <= maxRetries; attempt++) {
try {
await processMessage(message);
return;
} catch (error) {
if (attempt === maxRetries) {
// Send to dead letter queue
await sendToDeadLetterQueue(message, error);
} else {
await sleep(Math.pow(2, attempt) * 1000); // Exponential backoff
}
}
}
}
3. Consumer Groups
// Use consumer groups for scalability
// Same group = load balancing
// Different groups = broadcast
const consumer = kafka.consumer({
groupId: 'order-processing-group'
});
4. Idempotency
// Ensure idempotent message processing
async function processMessage(message) {
const messageId = message.headers.messageId?.toString(); // headers arrive as Buffers
// Check if already processed
const processed = await redis.get(`processed:${messageId}`);
if (processed) {
console.log('Message already processed');
return;
}
// Process message
await doWork(message);
// Mark as processed
await redis.set(`processed:${messageId}`, '1', { EX: 86400 });
}
5. Monitoring
// KafkaJS exposes instrumentation events rather than pluggable metric reporters
const producer = kafka.producer();
producer.on(producer.events.REQUEST, (event) => {
console.log('Request metrics:', event.payload);
});
// Monitor consumer lag via the admin client
const admin = kafka.admin();
await admin.connect();
await admin.fetchOffsets({
groupId: 'my-group',
topics: ['my-topic']
});
Production Considerations
High Availability
// Multiple brokers for redundancy
const kafka = new Kafka({
clientId: 'my-app',
brokers: [
'kafka1:9092',
'kafka2:9092',
'kafka3:9092'
],
retry: {
retries: 10,
initialRetryTime: 300,
multiplier: 2
}
});
// Replication factor for topics (admin = kafka.admin(), already connected)
await admin.createTopics({
topics: [{
topic: 'critical-events',
numPartitions: 6,
replicationFactor: 3 // Data replicated on 3 brokers
}]
});
Performance Tuning
// Producer optimization
const { CompressionTypes } = require('kafkajs');
const producer = kafka.producer({
idempotent: true, // Exactly-once semantics per partition
maxInFlightRequests: 1 // Keep at 1 when idempotent is enabled
});
// Compression is specified per send in KafkaJS; batching is handled automatically
await producer.send({
topic: 'my-topic',
compression: CompressionTypes.GZIP,
messages: [...]
});
// Consumer optimization
const consumer = kafka.consumer({
groupId: 'my-group',
sessionTimeout: 30000,
heartbeatInterval: 3000,
maxBytesPerPartition: 1048576,
maxWaitTimeInMs: 5000
});
Security
const kafka = new Kafka({
clientId: 'secure-app',
brokers: ['kafka:9093'],
ssl: true,
sasl: {
mechanism: 'plain',
username: 'my-username',
password: 'my-password'
}
});
Resources
Official Documentation:
Tools:
- Kafka UI - Web UI for Kafka
- Kafdrop - Kafka Web UI
- Kafka Tool - GUI
Learning:
Web Development
Modern web development covering frontend, backend, APIs, and full-stack technologies.
Topics Covered
Frontend Frameworks
- React: Components, hooks, state management
- Next.js: Production-ready React framework with SSR, SSG, and API routes
- Vue.js: Progressive framework with Composition API
- Svelte: Compiled framework with reactive programming
- Tailwind CSS: Utility-first CSS framework
Backend Frameworks
Node.js Frameworks
- Express.js: Minimal and flexible Node.js web framework
- NestJS: TypeScript-first progressive Node.js framework with dependency injection
Python Frameworks
- Django: High-level Python web framework for rapid development
- Flask: Lightweight and flexible Python microframework
- FastAPI: Modern, fast Python framework with automatic API documentation
Browser APIs
- Web APIs: Browser APIs for storage, workers, notifications, and more
- Storage: localStorage, sessionStorage, IndexedDB, Cache API
- Workers: Web Workers, Service Workers, Shared Workers
- Notifications: Notification API, Push API
- Device APIs: Geolocation, Battery Status
- File APIs: File, Blob, FileReader
- Observers: Intersection, Mutation, Resize Observer
- Other: Clipboard, History, Performance, Page Visibility
API & Communication
- REST APIs: RESTful API design and best practices
- GraphQL: Query language, schema design
- gRPC: High-performance RPC framework with Protocol Buffers
Other Topics
- Frontend: HTML, CSS, JavaScript fundamentals
- Backend: Express.js, Node.js
- Authentication: JWT, OAuth, sessions
- Databases: Integration with web apps
- Deployment: Hosting, CI/CD for web
Frontend Stack
- HTML/CSS/JavaScript
- React, Vue, or Angular
- State management (Redux, Vuex)
- Build tools (Webpack, Vite)
Backend Stack
- Node.js/Express or Python/Django
- REST or GraphQL APIs
- Database integration
- Authentication & authorization
Full Stack
Combining frontend and backend to build complete applications.
Navigation
Explore each topic to build modern web applications.
React
Overview
React is a JavaScript library for building user interfaces with reusable components and efficient rendering.
Components
Functional Components (Modern)
function Welcome({ name }) {
return <h1>Hello, {name}!</h1>;
}
// Arrow function
const Greeting = ({ message }) => <p>{message}</p>;
Class Components (Legacy)
class Welcome extends React.Component {
render() {
return <h1>Hello, {this.props.name}!</h1>;
}
}
Hooks
Modern way to manage state and effects:
import { useState, useEffect } from 'react';
function Counter() {
const [count, setCount] = useState(0);
useEffect(() => {
console.log('Count changed:', count);
// Cleanup
return () => console.log('Cleanup');
}, [count]); // Dependencies
return (
<div>
<p>Count: {count}</p>
<button onClick={() => setCount(count + 1)}>Increment</button>
</div>
);
}
Common Hooks
| Hook | Purpose |
|---|---|
| useState | Manage state |
| useEffect | Side effects |
| useContext | Access context |
| useReducer | Complex state logic |
| useCallback | Memoize function |
| useMemo | Memoize value |
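useReducer centralizes state transitions in a plain function, which can be written and tested outside React. A minimal sketch (a counter reducer with hypothetical action names):

```javascript
// A reducer is a pure function: (state, action) => newState.
// React wires it up via:
//   const [state, dispatch] = useReducer(counterReducer, { count: 0 });
function counterReducer(state, action) {
  switch (action.type) {
    case 'increment':
      return { ...state, count: state.count + 1 };
    case 'decrement':
      return { ...state, count: state.count - 1 };
    case 'reset':
      return { ...state, count: 0 };
    default:
      return state; // Unknown actions leave state unchanged
  }
}
```

In a component, `dispatch({ type: 'increment' })` runs the reducer and re-renders with the new state.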
Props
// Parent
<Child name="John" age={30} onClick={handleClick} />
// Child
function Child({ name, age, onClick }) {
return (
<div onClick={onClick}>
{name} is {age}
</div>
);
}
Conditional Rendering
{isLoggedIn && <Dashboard />}
{user ? <UserProfile /> : <LoginForm />}
{status === 'loading' && <Spinner />}
{status === 'error' && <Error />}
{status === 'success' && <Data />}
Lists
const users = [
{ id: 1, name: 'John' },
{ id: 2, name: 'Jane' }
];
<ul>
{users.map(user => (
<li key={user.id}>{user.name}</li>
))}
</ul>
Event Handling
function Button() {
const handleClick = (e) => {
console.log('Clicked');
};
const handleChange = (e) => {
const value = e.target.value;
};
return (
<>
<button onClick={handleClick}>Click</button>
<input onChange={handleChange} />
</>
);
}
Forms
function LoginForm() {
const [email, setEmail] = useState('');
const [password, setPassword] = useState('');
const handleSubmit = (e) => {
e.preventDefault();
console.log(email, password);
};
return (
<form onSubmit={handleSubmit}>
<input
type="email"
value={email}
onChange={(e) => setEmail(e.target.value)}
/>
<input
type="password"
value={password}
onChange={(e) => setPassword(e.target.value)}
/>
<button type="submit">Login</button>
</form>
);
}
State Management
Local State (useState)
const [state, setState] = useState(initialValue);
Context API (Global)
import { createContext, useContext } from 'react';
const UserContext = createContext();
function App() {
return (
<UserContext.Provider value={{ user: 'John' }}>
<Child />
</UserContext.Provider>
);
}
function Child() {
const { user } = useContext(UserContext);
}
Redux (Complex)
- Centralized store
- Actions → Reducers → State
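The Actions → Reducers → State flow can be seen in a toy store modeled on Redux's core API (this createStore is a simplified sketch for illustration, not the real library):

```javascript
// Minimal Redux-style store: dispatch routes actions through the reducer,
// and subscribers are notified after every state change.
function createStore(reducer, initialState) {
  let state = initialState;
  const listeners = [];
  return {
    getState: () => state,
    dispatch(action) {
      state = reducer(state, action);
      listeners.forEach((fn) => fn());
    },
    subscribe(fn) {
      listeners.push(fn);
    },
  };
}

const store = createStore(
  (state, action) =>
    action.type === 'ADD_TODO'
      ? { todos: [...state.todos, action.text] }
      : state,
  { todos: [] }
);
store.dispatch({ type: 'ADD_TODO', text: 'Learn Redux' });
```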
Lifecycle (Class Components)
componentDidMount() { } // After render
componentDidUpdate() { } // After update
componentWillUnmount() { } // Before remove
Best Practices
- Functional components (with hooks)
- Keep components small
- Lift state up when needed
- Use keys in lists
- Memoize expensive computations
- Lazy load components
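The "memoize expensive computations" point is what useMemo does per render; the underlying idea is caching results by input, as in this framework-free sketch:

```javascript
// Cache results of an expensive pure function, keyed by its (stringified)
// arguments. useMemo applies the same idea inside a component, re-computing
// only when the dependency array changes.
function memoize(fn) {
  const cache = new Map();
  return (...args) => {
    const key = JSON.stringify(args);
    if (!cache.has(key)) {
      cache.set(key, fn(...args));
    }
    return cache.get(key);
  };
}

let calls = 0;
const slowSquare = memoize((n) => {
  calls += 1; // Track how often the real computation actually runs
  return n * n;
});
```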
ELI10
React is like LEGO blocks:
- Build reusable pieces (components)
- Combine to make complex UIs
- Reuse same piece many times
- Efficient updates when data changes!
Next.js
Next.js is a production-ready React framework that provides server-side rendering, static site generation, API routes, and many other features out of the box. Built by Vercel, it's designed to give you the best developer experience with all the features needed for production.
Table of Contents
- Introduction
- Installation and Setup
- File-Based Routing
- Pages and Layouts
- Data Fetching
- API Routes
- Dynamic Routes
- Image Optimization
- CSS and Styling
- Authentication
- Deployment
- Best Practices
Introduction
Key Features:
- Server-Side Rendering (SSR)
- Static Site Generation (SSG)
- Incremental Static Regeneration (ISR)
- API Routes
- File-based routing
- Automatic code splitting
- Built-in image optimization
- TypeScript support
- Fast Refresh
- Zero configuration
Use Cases:
- E-commerce websites
- Marketing websites
- Blogs and content sites
- Dashboards
- SaaS applications
- Web frontends paired with React Native mobile apps
Installation and Setup
Create New Project
# Create Next.js app
npx create-next-app@latest my-next-app
cd my-next-app
# Or with TypeScript
npx create-next-app@latest my-next-app --typescript
# Start development server
npm run dev
Project Structure
my-next-app/
├── app/ # App directory (Next.js 13+)
│ ├── layout.tsx # Root layout
│ ├── page.tsx # Home page
│ ├── api/ # API routes
│ └── [folder]/ # Routes
├── public/ # Static files
├── components/ # React components
├── lib/ # Utility functions
├── styles/ # CSS files
├── next.config.js # Next.js configuration
├── package.json
└── tsconfig.json # TypeScript configuration
Configuration
next.config.js:
/** @type {import('next').NextConfig} */
const nextConfig = {
reactStrictMode: true,
images: {
domains: ['example.com', 'cdn.example.com'],
},
env: {
CUSTOM_KEY: process.env.CUSTOM_KEY,
},
async rewrites() {
return [
{
source: '/api/:path*',
destination: 'https://api.example.com/:path*',
},
]
},
}
module.exports = nextConfig
File-Based Routing
App Router (Next.js 13+)
app/
├── page.tsx # / route
├── about/
│ └── page.tsx # /about route
├── blog/
│ ├── page.tsx # /blog route
│ └── [slug]/
│ └── page.tsx # /blog/[slug] route
└── dashboard/
├── layout.tsx # Dashboard layout
├── page.tsx # /dashboard route
└── settings/
└── page.tsx # /dashboard/settings route
app/page.tsx:
import Link from 'next/link'
export default function Home() {
return (
<main>
<h1>Welcome to Next.js</h1>
<Link href="/about">About</Link>
<Link href="/blog">Blog</Link>
</main>
)
}
app/about/page.tsx:
export default function About() {
return (
<div>
<h1>About Us</h1>
<p>This is the about page</p>
</div>
)
}
Pages and Layouts
Root Layout
app/layout.tsx:
import type { Metadata } from 'next'
import { Inter } from 'next/font/google'
import './globals.css'
const inter = Inter({ subsets: ['latin'] })
export const metadata: Metadata = {
title: 'My Next.js App',
description: 'Built with Next.js',
}
export default function RootLayout({
children,
}: {
children: React.ReactNode
}) {
return (
<html lang="en">
<body className={inter.className}>
<nav>
<a href="/">Home</a>
<a href="/about">About</a>
<a href="/blog">Blog</a>
</nav>
{children}
<footer>© 2024 My App</footer>
</body>
</html>
)
}
Nested Layouts
app/dashboard/layout.tsx:
export default function DashboardLayout({
children,
}: {
children: React.ReactNode
}) {
return (
<div className="dashboard">
<aside>
<nav>
<a href="/dashboard">Overview</a>
<a href="/dashboard/settings">Settings</a>
<a href="/dashboard/profile">Profile</a>
</nav>
</aside>
<main>{children}</main>
</div>
)
}
Loading and Error States
app/loading.tsx:
export default function Loading() {
return <div>Loading...</div>
}
app/error.tsx:
'use client'
export default function Error({
error,
reset,
}: {
error: Error & { digest?: string }
reset: () => void
}) {
return (
<div>
<h2>Something went wrong!</h2>
<p>{error.message}</p>
<button onClick={reset}>Try again</button>
</div>
)
}
Data Fetching
Server Components (Default)
async function getData() {
const res = await fetch('https://api.example.com/data', {
cache: 'no-store', // or 'force-cache'
})
if (!res.ok) {
throw new Error('Failed to fetch data')
}
return res.json()
}
export default async function Page() {
const data = await getData()
return (
<div>
<h1>Data from API</h1>
<pre>{JSON.stringify(data, null, 2)}</pre>
</div>
)
}
Static Generation
async function getStaticData() {
const res = await fetch('https://api.example.com/posts')
return res.json()
}
export default async function BlogPage() {
const posts = await getStaticData()
return (
<div>
{posts.map((post: any) => (
<article key={post.id}>
<h2>{post.title}</h2>
<p>{post.excerpt}</p>
</article>
))}
</div>
)
}
// Revalidate every hour
export const revalidate = 3600
Dynamic Data with Params
async function getPost(slug: string) {
const res = await fetch(`https://api.example.com/posts/${slug}`)
return res.json()
}
export default async function Post({ params }: { params: { slug: string } }) {
const post = await getPost(params.slug)
return (
<article>
<h1>{post.title}</h1>
<div dangerouslySetInnerHTML={{ __html: post.content }} />
</article>
)
}
// Generate static params for dynamic routes
export async function generateStaticParams() {
const posts = await fetch('https://api.example.com/posts').then((res) =>
res.json()
)
return posts.map((post: any) => ({
slug: post.slug,
}))
}
Client Components
'use client'
import { useState, useEffect } from 'react'
export default function ClientComponent() {
const [data, setData] = useState(null)
const [loading, setLoading] = useState(true)
useEffect(() => {
fetch('/api/data')
.then((res) => res.json())
.then((data) => {
setData(data)
setLoading(false)
})
}, [])
if (loading) return <div>Loading...</div>
return <div>{JSON.stringify(data)}</div>
}
API Routes
Basic API Route
app/api/hello/route.ts:
import { NextResponse } from 'next/server'
export async function GET() {
return NextResponse.json({ message: 'Hello from Next.js!' })
}
export async function POST(request: Request) {
const body = await request.json()
return NextResponse.json({ received: body })
}
Dynamic API Routes
app/api/users/[id]/route.ts:
import { NextResponse } from 'next/server'
export async function GET(
request: Request,
{ params }: { params: { id: string } }
) {
const id = params.id
// Fetch user from database
const user = await fetchUser(id)
if (!user) {
return NextResponse.json({ error: 'User not found' }, { status: 404 })
}
return NextResponse.json(user)
}
export async function PUT(
request: Request,
{ params }: { params: { id: string } }
) {
const id = params.id
const body = await request.json()
// Update user in database
const updatedUser = await updateUser(id, body)
return NextResponse.json(updatedUser)
}
export async function DELETE(
request: Request,
{ params }: { params: { id: string } }
) {
const id = params.id
await deleteUser(id)
return NextResponse.json({ message: 'User deleted' })
}
API with Database
import { NextResponse } from 'next/server'
import { prisma } from '@/lib/prisma'
export async function GET() {
try {
const users = await prisma.user.findMany()
return NextResponse.json(users)
} catch (error) {
return NextResponse.json(
{ error: 'Failed to fetch users' },
{ status: 500 }
)
}
}
export async function POST(request: Request) {
try {
const body = await request.json()
const user = await prisma.user.create({
data: body,
})
return NextResponse.json(user, { status: 201 })
} catch (error) {
return NextResponse.json(
{ error: 'Failed to create user' },
{ status: 500 }
)
}
}
Dynamic Routes
Catch-All Routes
app/shop/[...slug]/page.tsx:
export default function ShopPage({ params }: { params: { slug: string[] } }) {
return (
<div>
<h1>Shop</h1>
<p>Category: {params.slug.join('/')}</p>
</div>
)
}
// Matches:
// /shop/electronics
// /shop/electronics/laptops
// /shop/electronics/laptops/gaming
Optional Catch-All Routes
app/docs/[[...slug]]/page.tsx:
export default function DocsPage({
params,
}: {
params: { slug?: string[] }
}) {
if (!params.slug) {
return <div>Documentation Home</div>
}
return <div>Path: {params.slug.join('/')}</div>
}
// Matches:
// /docs
// /docs/getting-started
// /docs/api/reference
Image Optimization
import Image from 'next/image'
export default function ImageExample() {
return (
<div>
{/* Static Image */}
<Image
src="/hero.jpg"
alt="Hero"
width={1200}
height={600}
priority
/>
{/* External Image */}
<Image
src="https://example.com/image.jpg"
alt="External"
width={800}
height={600}
quality={85}
/>
{/* Responsive Image */}
<Image
src="/profile.jpg"
alt="Profile"
fill
sizes="(max-width: 768px) 100vw, 50vw"
style={{ objectFit: 'cover' }}
/>
{/* With Placeholder */}
<Image
src="/photo.jpg"
alt="Photo"
width={600}
height={400}
placeholder="blur"
blurDataURL="data:image/jpeg;base64,..."
/>
</div>
)
}
CSS and Styling
CSS Modules
components/Button.module.css:
.button {
padding: 12px 24px;
background: blue;
color: white;
border: none;
border-radius: 4px;
cursor: pointer;
}
.button:hover {
background: darkblue;
}
components/Button.tsx:
import styles from './Button.module.css'
export default function Button({ children }: { children: React.ReactNode }) {
return <button className={styles.button}>{children}</button>
}
Tailwind CSS
npm install -D tailwindcss postcss autoprefixer
npx tailwindcss init -p
tailwind.config.js:
module.exports = {
content: [
'./app/**/*.{js,ts,jsx,tsx,mdx}',
'./components/**/*.{js,ts,jsx,tsx,mdx}',
],
theme: {
extend: {},
},
plugins: [],
}
app/globals.css:
@tailwind base;
@tailwind components;
@tailwind utilities;
Usage:
export default function Home() {
return (
<div className="min-h-screen bg-gray-100">
<h1 className="text-4xl font-bold text-blue-600">
Hello Tailwind!
</h1>
<button className="px-4 py-2 bg-blue-500 text-white rounded hover:bg-blue-600">
Click Me
</button>
</div>
)
}
Authentication
NextAuth.js
npm install next-auth
app/api/auth/[...nextauth]/route.ts:
import NextAuth from 'next-auth'
import GoogleProvider from 'next-auth/providers/google'
import CredentialsProvider from 'next-auth/providers/credentials'
const handler = NextAuth({
providers: [
GoogleProvider({
clientId: process.env.GOOGLE_CLIENT_ID!,
clientSecret: process.env.GOOGLE_CLIENT_SECRET!,
}),
CredentialsProvider({
name: 'Credentials',
credentials: {
email: { label: "Email", type: "email" },
password: { label: "Password", type: "password" }
},
async authorize(credentials) {
// Verify credentials
const user = await verifyUser(credentials)
if (user) {
return user
}
return null
}
})
],
pages: {
signIn: '/auth/signin',
},
callbacks: {
async jwt({ token, user }) {
if (user) {
token.id = user.id
}
return token
},
async session({ session, token }) {
if (session.user) {
session.user.id = token.id as string
}
return session
},
},
})
export { handler as GET, handler as POST }
app/providers.tsx:
'use client'
import { SessionProvider } from 'next-auth/react'
export function Providers({ children }: { children: React.ReactNode }) {
return <SessionProvider>{children}</SessionProvider>
}
Protected Route:
import { getServerSession } from 'next-auth'
import { redirect } from 'next/navigation'
export default async function DashboardPage() {
const session = await getServerSession()
if (!session) {
redirect('/auth/signin')
}
return (
<div>
<h1>Dashboard</h1>
<p>Welcome, {session.user?.name}</p>
</div>
)
}
Client-Side Auth:
'use client'
import { useSession, signIn, signOut } from 'next-auth/react'
export default function LoginButton() {
const { data: session, status } = useSession()
if (status === 'loading') {
return <div>Loading...</div>
}
if (session) {
return (
<>
<p>Signed in as {session.user?.email}</p>
<button onClick={() => signOut()}>Sign out</button>
</>
)
}
return <button onClick={() => signIn()}>Sign in</button>
}
Deployment
Vercel (Recommended)
# Install Vercel CLI
npm i -g vercel
# Deploy
vercel
# Production deployment
vercel --prod
Docker
Dockerfile:
FROM node:18-alpine AS base
FROM base AS deps
WORKDIR /app
COPY package.json package-lock.json ./
RUN npm ci
FROM base AS builder
WORKDIR /app
COPY --from=deps /app/node_modules ./node_modules
COPY . .
RUN npm run build
FROM base AS runner
WORKDIR /app
ENV NODE_ENV=production
COPY --from=builder /app/public ./public
COPY --from=builder /app/.next/standalone ./
COPY --from=builder /app/.next/static ./.next/static
EXPOSE 3000
CMD ["node", "server.js"]
next.config.js:
module.exports = {
output: 'standalone',
}
Environment Variables
.env.local:
DATABASE_URL="postgresql://..."
NEXTAUTH_SECRET="your-secret"
NEXTAUTH_URL="http://localhost:3000"
GOOGLE_CLIENT_ID="..."
GOOGLE_CLIENT_SECRET="..."
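Server code reads these through process.env; only variables prefixed with NEXT_PUBLIC_ are exposed to the browser. A small helper for failing fast on missing values (requireEnv is a hypothetical name, not a Next.js API):

```javascript
// Read a required environment variable, throwing early if it is unset so
// misconfiguration surfaces at startup rather than mid-request.
function requireEnv(name) {
  const value = process.env[name];
  if (value === undefined || value === '') {
    throw new Error(`Missing required environment variable: ${name}`);
  }
  return value;
}

// Example (server-side only; secrets must never use the NEXT_PUBLIC_ prefix):
// const dbUrl = requireEnv('DATABASE_URL');
```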
Best Practices
1. Server vs Client Components
// Server Component (default) - Use for:
// - Data fetching
// - Direct database access
// - API calls
export default async function ServerComponent() {
const data = await fetchData()
return <div>{data}</div>
}
// Client Component - Use for:
// - Interactivity (onClick, onChange, etc.)
// - State management
// - Browser APIs
'use client'
export default function ClientComponent() {
const [count, setCount] = useState(0)
return <button onClick={() => setCount(count + 1)}>{count}</button>
}
2. Data Fetching Strategies
// Static - Fetch at build time
export const revalidate = false
// ISR - Revalidate every 60 seconds
export const revalidate = 60
// Dynamic - Fetch on every request
export const dynamic = 'force-dynamic'
// Cache specific requests
fetch('https://api.example.com/data', {
next: { revalidate: 3600 } // Revalidate every hour
})
3. Metadata
import type { Metadata } from 'next'
export const metadata: Metadata = {
title: 'My Page',
description: 'Page description',
openGraph: {
title: 'My Page',
description: 'Page description',
images: ['/og-image.jpg'],
},
twitter: {
card: 'summary_large_image',
},
}
4. Error Boundaries
// app/error.tsx
'use client'
import { useEffect } from 'react'
export default function Error({
error,
reset,
}: {
error: Error
reset: () => void
}) {
useEffect(() => {
console.error(error)
}, [error])
return (
<div>
<h2>Something went wrong!</h2>
<button onClick={reset}>Try again</button>
</div>
)
}
5. Performance Optimization
// Dynamic imports
import dynamic from 'next/dynamic'
const DynamicComponent = dynamic(() => import('@/components/Heavy'), {
loading: () => <p>Loading...</p>,
ssr: false, // Disable SSR for this component
})
// Lazy load images
<Image
src="/photo.jpg"
alt="Photo"
loading="lazy"
width={600}
height={400}
/>
Vue.js
Vue.js is a progressive JavaScript framework for building user interfaces. It's designed to be incrementally adoptable and focuses on the view layer.
Installation
# Create Vue 3 project
npm create vue@latest my-app
cd my-app
npm install
npm run dev
# Or via CDN
<script src="https://unpkg.com/vue@3"></script>
Component Basics
<!-- HelloWorld.vue -->
<template>
<div>
<h1>{{ message }}</h1>
<button @click="increment">Count: {{ count }}</button>
</div>
</template>
<script setup>
import { ref } from 'vue'
const message = ref('Hello Vue!')
const count = ref(0)
function increment() {
count.value++
}
</script>
<style scoped>
h1 {
color: #42b983;
}
</style>
Reactivity
<script setup>
import { ref, reactive, computed, watch } from 'vue'
// Refs
const count = ref(0)
// Reactive objects
const state = reactive({
name: 'John',
age: 30
})
// Computed properties
const doubled = computed(() => count.value * 2)
// Watchers
watch(count, (newVal, oldVal) => {
console.log(`Count changed from ${oldVal} to ${newVal}`)
})
</script>
Props and Emits
<!-- Child.vue -->
<script setup>
const props = defineProps({
title: String,
count: {
type: Number,
default: 0
}
})
const emit = defineEmits(['update', 'delete'])
function handleClick() {
emit('update', { id: 1, value: 'new' })
}
</script>
<template>
<h2>{{ title }}</h2>
<button @click="handleClick">Update</button>
</template>
<!-- Parent.vue -->
<Child
title="My Component"
:count="10"
@update="handleUpdate"
/>
Directives
<template>
<!-- Conditional rendering -->
<div v-if="show">Visible</div>
<div v-else>Hidden</div>
<!-- List rendering -->
<ul>
<li v-for="item in items" :key="item.id">
{{ item.name }}
</li>
</ul>
<!-- Two-way binding -->
<input v-model="text" />
<!-- Event handling -->
<button @click="handleClick">Click me</button>
<!-- Dynamic attributes -->
<img :src="imageUrl" :alt="description" />
</template>
Lifecycle Hooks
<script setup>
import { onMounted, onUpdated, onUnmounted } from 'vue'
onMounted(() => {
console.log('Component mounted')
})
onUpdated(() => {
console.log('Component updated')
})
onUnmounted(() => {
console.log('Component unmounted')
})
</script>
Quick Reference
| Feature | Syntax |
|---|---|
| Data binding | {{ variable }} |
| Attribute binding | :attribute="value" |
| Event handling | @event="handler" |
| Two-way binding | v-model="variable" |
| Conditional | v-if, v-else-if, v-else |
| Loop | v-for="item in items" |
Vue.js provides an approachable, versatile, and performant framework for building modern web interfaces.
Svelte
Svelte is a radical new approach to building user interfaces. Unlike frameworks that do the bulk of their work in the browser, Svelte shifts that work into a compile step.
Installation
# Create new Svelte project
npm create vite@latest my-app -- --template svelte
cd my-app
npm install
npm run dev
Component Basics
<!-- App.svelte -->
<script>
let count = 0;
function increment() {
count += 1;
}
</script>
<button on:click={increment}>
Clicked {count} {count === 1 ? 'time' : 'times'}
</button>
<style>
button {
background: #ff3e00;
color: white;
padding: 10px 20px;
border: none;
border-radius: 5px;
cursor: pointer;
}
</style>
Reactivity
<script>
let count = 0;
// Reactive declaration
$: doubled = count * 2;
// Reactive statement
$: if (count >= 10) {
alert('count is high!');
}
// Reactive block
$: {
console.log(`count is ${count}`);
}
</script>
Props
<!-- Child.svelte -->
<script>
export let name;
export let age = 25; // default value
</script>
<p>{name} is {age} years old</p>
<!-- Parent.svelte -->
<script>
import Child from './Child.svelte';
</script>
<Child name="John" age={30} />
Events
<script>
import { createEventDispatcher } from 'svelte';
const dispatch = createEventDispatcher();
function handleClick() {
dispatch('message', { text: 'Hello!' });
}
</script>
<button on:click={handleClick}>
Send message
</button>
<!-- Parent -->
<Child on:message={e => console.log(e.detail.text)} />
Stores
// store.js
import { writable } from 'svelte/store';
export const count = writable(0);
<script>
import { count } from './store.js';
</script>
<button on:click={() => $count += 1}>
Count: {$count}
</button>
Quick Reference
| Feature | Syntax |
|---|---|
| Reactive variable | $: value = ... |
| Event handler | on:click={handler} |
| Two-way binding | bind:value={variable} |
| Conditional | {#if condition}...{/if} |
| Loop | {#each items as item}...{/each} |
| Await | {#await promise}...{/await} |
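The block syntaxes in the table compose directly in markup; a sketch showing {#if}, {#each}, and {#await} together (items, show, and loadUser are hypothetical):

```svelte
<script>
  let items = ['a', 'b'];
  let show = true;
  // Hypothetical async loader for the {#await} block below
  const loadUser = Promise.resolve({ name: 'Ada' });
</script>

{#if show}
  <ul>
    {#each items as item (item)}
      <li>{item}</li>
    {/each}
  </ul>
{/if}

{#await loadUser}
  <p>Loading…</p>
{:then user}
  <p>Hello {user.name}</p>
{:catch error}
  <p>{error.message}</p>
{/await}
```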
Svelte compiles components to highly efficient imperative code, resulting in small bundle sizes and excellent performance.
Tailwind CSS
Tailwind CSS is a utility-first CSS framework for rapidly building custom user interfaces. Unlike traditional CSS frameworks that provide pre-designed components (like Bootstrap), Tailwind provides low-level utility classes that let you build completely custom designs without ever leaving your HTML.
Key Philosophy: Instead of fighting framework conventions, Tailwind gives you the building blocks to create your own design system with utility classes that can be composed to build any design directly in your markup.
Table of Contents
- Introduction
- Installation and Setup
- Configuration
- Core Concepts
- Utility Classes
- Responsive Design
- State Variants
- Dark Mode
- Component Patterns
- Layout Patterns
- Customization
- Plugin System
- Framework Integration
- Advanced Topics
- Performance Optimization
- Best Practices
- Accessibility
- Migration and Comparison
- Tooling and Ecosystem
- Resources
Introduction
What is Tailwind CSS?
Tailwind CSS is a utility-first CSS framework that provides single-purpose utility classes for building user interfaces. Instead of writing custom CSS, you compose these utilities directly in your HTML.
Traditional CSS approach:
<div class="chat-notification">
<div class="chat-notification-logo-wrapper">
<img class="chat-notification-logo" src="logo.svg" alt="Logo">
</div>
<div class="chat-notification-content">
<h4 class="chat-notification-title">New message</h4>
<p class="chat-notification-message">You have a new message!</p>
</div>
</div>
Tailwind approach:
<div class="flex items-center p-6 max-w-sm mx-auto bg-white rounded-xl shadow-lg">
<div class="shrink-0">
<img class="h-12 w-12" src="logo.svg" alt="Logo">
</div>
<div class="ml-4">
<h4 class="text-xl font-medium text-black">New message</h4>
<p class="text-gray-500">You have a new message!</p>
</div>
</div>
Key Features
- Utility-First: Compose designs from utility classes instead of writing custom CSS
- Responsive: Mobile-first breakpoints built into every utility
- Component-Friendly: Easy to extract components when needed
- Customizable: Extensive theming and configuration options
- Modern: Supports CSS Grid, Flexbox, transforms, transitions, and more
- Dark Mode: First-class dark mode support
- JIT Mode: Generate styles on-demand for faster builds
- Production-Optimized: Automatically removes unused CSS
Use Cases
Perfect for:
- Web applications and dashboards
- Marketing websites and landing pages
- Rapid prototyping
- Design systems and component libraries
- Projects requiring custom designs
Maybe not ideal for:
- Simple static sites (might be overkill)
- Teams resistant to utility-first approach
- Projects with very limited HTML access
Tailwind vs Traditional CSS
| Aspect | Tailwind | Traditional CSS |
|---|---|---|
| Approach | Utility-first | Semantic class names |
| Workflow | Compose in HTML | Write CSS separately |
| File Switching | Minimal | Constant (HTML ↔ CSS) |
| Naming | No naming needed | Need to invent class names |
| Bundle Size | Small (purged) | Grows over time |
| Customization | Config-based | Manual CSS |
| Learning Curve | Learn utilities | Learn CSS deeply |
Tailwind vs Bootstrap
| Feature | Tailwind | Bootstrap |
|---|---|---|
| Philosophy | Utility-first | Component-first |
| Customization | Highly flexible | Limited to overrides |
| Design | Build your own | Pre-designed look |
| File Size | Smaller (purged) | Larger base |
| Components | Build from utilities | Ready-made |
| Learning | Utility classes | Component classes |
Installation and Setup
NPM/Yarn Installation
# Install Tailwind CSS
npm install -D tailwindcss postcss autoprefixer
# Initialize configuration
npx tailwindcss init
Complete Setup
1. Create config files:
# Create both tailwind.config.js and postcss.config.js
npx tailwindcss init -p
2. Configure template paths (tailwind.config.js):
/** @type {import('tailwindcss').Config} */
module.exports = {
content: [
"./index.html",
"./src/**/*.{js,ts,jsx,tsx}",
],
theme: {
extend: {},
},
plugins: [],
}
3. Add Tailwind directives to CSS (src/index.css):
@tailwind base;
@tailwind components;
@tailwind utilities;
4. Import CSS in your app:
// main.js or App.jsx
import './index.css'
Framework-Specific Setup
React / Next.js
# Next.js (automatic with create-next-app)
npx create-next-app@latest my-project --tailwind
# Manual setup for existing React project
npm install -D tailwindcss postcss autoprefixer
npx tailwindcss init -p
Next.js config:
// tailwind.config.js
module.exports = {
content: [
'./pages/**/*.{js,ts,jsx,tsx,mdx}',
'./components/**/*.{js,ts,jsx,tsx,mdx}',
'./app/**/*.{js,ts,jsx,tsx,mdx}',
],
theme: {
extend: {},
},
plugins: [],
}
Vue / Nuxt
# Nuxt 3
npm install -D @nuxtjs/tailwindcss
nuxt.config.ts:
export default defineNuxtConfig({
modules: ['@nuxtjs/tailwindcss']
})
Svelte / SvelteKit
npx svelte-add@latest tailwindcss
npm install
Vite
npm install -D tailwindcss postcss autoprefixer
npx tailwindcss init -p
vite.config.js:
import { defineConfig } from 'vite'
export default defineConfig({
css: {
postcss: './postcss.config.js',
},
})
CDN (Development Only)
<!DOCTYPE html>
<html>
<head>
<!-- Include via CDN (no build step, but no customization) -->
<script src="https://cdn.tailwindcss.com"></script>
<!-- Optional: Configure via script tag -->
<script>
tailwind.config = {
theme: {
extend: {
colors: {
brand: '#3b82f6',
}
}
}
}
</script>
</head>
<body>
<h1 class="text-3xl font-bold text-brand">
Hello Tailwind!
</h1>
</body>
</html>
⚠️ CDN Warning: Don't use in production. No purging, no optimization, large file size.
Tailwind CLI
For projects without a build tool:
# Install
npm install -D tailwindcss
# Initialize
npx tailwindcss init
# Build CSS
npx tailwindcss -i ./src/input.css -o ./dist/output.css --watch
# Production build
npx tailwindcss -i ./src/input.css -o ./dist/output.css --minify
Configuration
Basic tailwind.config.js
/** @type {import('tailwindcss').Config} */
module.exports = {
// Files to scan for class names
content: [
"./index.html",
"./src/**/*.{js,jsx,ts,tsx,vue,svelte}",
],
// Dark mode configuration
darkMode: 'class', // or 'media'
// Theme customization
theme: {
// Replace default theme
screens: {
sm: '640px',
md: '768px',
lg: '1024px',
xl: '1280px',
'2xl': '1536px',
},
// Extend default theme (recommended)
extend: {
colors: {
brand: {
50: '#eff6ff',
100: '#dbeafe',
200: '#bfdbfe',
300: '#93c5fd',
400: '#60a5fa',
500: '#3b82f6',
600: '#2563eb',
700: '#1d4ed8',
800: '#1e40af',
900: '#1e3a8a',
},
},
spacing: {
'128': '32rem',
'144': '36rem',
},
borderRadius: {
'4xl': '2rem',
},
fontFamily: {
sans: ['Inter', 'sans-serif'],
display: ['Lexend', 'sans-serif'],
},
},
},
// Plugins
plugins: [],
}
Content Configuration
Tell Tailwind where to look for classes:
module.exports = {
content: [
// HTML files
'./public/**/*.html',
// JavaScript/TypeScript
'./src/**/*.{js,jsx,ts,tsx}',
// Vue components
'./src/**/*.vue',
// Svelte components
'./src/**/*.svelte',
// PHP files (for WordPress, Laravel, etc.)
'./templates/**/*.php',
// Use safelist for dynamic classes
],
// Safelist classes that might be generated dynamically
safelist: [
'bg-red-500',
'bg-green-500',
'bg-blue-500',
// Or use patterns
{
pattern: /bg-(red|green|blue)-(100|500|900)/,
},
],
}
Theme Extension
module.exports = {
theme: {
// Extend default theme (adds to existing)
extend: {
// Custom colors
colors: {
primary: '#3b82f6',
secondary: '#8b5cf6',
danger: '#ef4444',
},
// Custom spacing values
spacing: {
'128': '32rem',
'144': '36rem',
},
// Custom font sizes
fontSize: {
'xxs': '0.625rem',
},
// Custom breakpoints
screens: {
'3xl': '1920px',
},
// Custom z-index values
zIndex: {
'100': '100',
},
// Custom animations
animation: {
'spin-slow': 'spin 3s linear infinite',
},
// Custom keyframes
keyframes: {
wiggle: {
'0%, 100%': { transform: 'rotate(-3deg)' },
'50%': { transform: 'rotate(3deg)' },
}
}
},
// Replace default theme (use sparingly)
// screens: { ... } // This replaces all default breakpoints
},
}
Using CSS Variables
// tailwind.config.js
module.exports = {
theme: {
extend: {
colors: {
primary: 'rgb(var(--color-primary) / <alpha-value>)',
secondary: 'rgb(var(--color-secondary) / <alpha-value>)',
},
},
},
}
/* In your CSS */
:root {
--color-primary: 59 130 246; /* RGB values */
--color-secondary: 139 92 246;
}
.dark {
--color-primary: 96 165 250;
--color-secondary: 167 139 250;
}
<!-- Use with opacity modifiers -->
<div class="bg-primary/50">Semi-transparent background</div>
Core Concepts
Utility-First Fundamentals
Instead of semantic class names, use utilities:
<!-- ❌ Traditional approach -->
<div class="card">
<h2 class="card-title">Title</h2>
<p class="card-body">Content</p>
</div>
<!-- ✅ Tailwind approach -->
<div class="bg-white rounded-lg shadow-md p-6">
<h2 class="text-xl font-bold mb-2">Title</h2>
<p class="text-gray-700">Content</p>
</div>
Benefits:
- No need to invent class names
- Changes are local (no cascade issues)
- CSS bundle size stays small
- Faster development
Responsive Design (Mobile-First)
All utilities can be prefixed with breakpoint names:
<!-- Mobile: full width, Desktop: half width -->
<div class="w-full md:w-1/2">
Responsive element
</div>
<!-- Mobile: column, Tablet+: row -->
<div class="flex flex-col md:flex-row">
<div>Item 1</div>
<div>Item 2</div>
</div>
Breakpoints:
- sm: 640px
- md: 768px
- lg: 1024px
- xl: 1280px
- 2xl: 1536px
Hover, Focus, and Other States
<!-- Hover state -->
<button class="bg-blue-500 hover:bg-blue-700">
Hover me
</button>
<!-- Focus state -->
<input class="border focus:border-blue-500 focus:ring-2 focus:ring-blue-200">
<!-- Multiple states -->
<button class="bg-blue-500 hover:bg-blue-600 active:bg-blue-700 disabled:bg-gray-300">
Button
</button>
Design Tokens and Constraints
Tailwind provides a constrained set of values (design tokens) for consistency:
<!-- Spacing scale: 0, 1, 2, 3, 4, 5, 6, 8, 10, 12, 16, 20, 24, 32, 40, 48, 56, 64... -->
<div class="p-4"> <!-- padding: 1rem -->
<div class="p-8"> <!-- padding: 2rem -->
<div class="p-16"> <!-- padding: 4rem -->
<!-- Color scale: 50, 100, 200, 300, 400, 500, 600, 700, 800, 900 -->
<div class="bg-blue-100"> <!-- Light blue -->
<div class="bg-blue-500"> <!-- Medium blue -->
<div class="bg-blue-900"> <!-- Dark blue -->
Use arbitrary values when needed:
<!-- Arbitrary values with [value] syntax -->
<div class="w-[347px]">Exact width</div>
<div class="bg-[#1da1f2]">Twitter blue</div>
<div class="text-[2.35rem]">Custom font size</div>
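The whole-number spacing steps follow a simple rule: one unit is 0.25rem (4px at the default 16px root font size), so p-4 is 1rem and w-32 is 8rem. A quick sketch of the mapping (helper names are illustrative; the px and fractional steps like 0.5 are special cases not covered here):

```javascript
// Convert a Tailwind spacing step (as in p-4, m-8, w-16) to rem and px.
const spacingToRem = (step) => step * 0.25;
const spacingToPx = (step, rootPx = 16) => spacingToRem(step) * rootPx;

console.log(spacingToRem(4), spacingToPx(4));   // → 1 16
console.log(spacingToRem(16), spacingToPx(16)); // → 4 64
```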
Utility Classes
Layout
Container
<!-- Centered container with max-width -->
<div class="container mx-auto px-4">
Content
</div>
<!-- Responsive max-widths by default:
sm: 640px
md: 768px
lg: 1024px
xl: 1280px
2xl: 1536px
-->
Display
<!-- Block, inline, inline-block -->
<div class="block">Block</div>
<div class="inline">Inline</div>
<div class="inline-block">Inline-block</div>
<!-- Flex and Grid -->
<div class="flex">Flexbox container</div>
<div class="inline-flex">Inline flex container</div>
<div class="grid">Grid container</div>
<div class="inline-grid">Inline grid container</div>
<!-- Hidden -->
<div class="hidden">Not displayed</div>
<div class="md:block">Hidden on mobile, shown on tablet+</div>
Flexbox
<!-- Flex direction -->
<div class="flex flex-row">Horizontal (default)</div>
<div class="flex flex-col">Vertical</div>
<div class="flex flex-row-reverse">Reversed horizontal</div>
<!-- Justify content (main axis) -->
<div class="flex justify-start">Start</div>
<div class="flex justify-center">Center</div>
<div class="flex justify-between">Space between</div>
<div class="flex justify-around">Space around</div>
<div class="flex justify-evenly">Space evenly</div>
<!-- Align items (cross axis) -->
<div class="flex items-start">Start</div>
<div class="flex items-center">Center</div>
<div class="flex items-end">End</div>
<div class="flex items-stretch">Stretch (default)</div>
<!-- Flex wrap -->
<div class="flex flex-wrap">Wrap</div>
<div class="flex flex-nowrap">No wrap (default)</div>
<!-- Flex grow/shrink -->
<div class="flex-1">Grow and shrink</div>
<div class="flex-auto">Auto sizing</div>
<div class="flex-none">Don't grow or shrink</div>
<div class="grow">Only grow</div>
<div class="shrink-0">Don't shrink</div>
<!-- Gap -->
<div class="flex gap-4">Gap between items</div>
<div class="flex gap-x-4 gap-y-2">Different x and y gaps</div>
Grid
<!-- Grid columns -->
<div class="grid grid-cols-3 gap-4">
<div>1</div>
<div>2</div>
<div>3</div>
</div>
<!-- Grid cols with different sizes -->
<div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
Responsive grid
</div>
<!-- Column span -->
<div class="grid grid-cols-3">
<div class="col-span-2">Spans 2 columns</div>
<div>1 column</div>
</div>
<!-- Auto-fit columns -->
<div class="grid grid-cols-[repeat(auto-fit,minmax(200px,1fr))] gap-4">
Auto-sizing grid
</div>
<!-- Grid rows -->
<div class="grid grid-rows-3 gap-4 h-64">
<div>Row 1</div>
<div>Row 2</div>
<div>Row 3</div>
</div>
<!-- Grid template areas (arbitrary value) -->
<div class="grid grid-rows-[auto_1fr_auto]">
<header>Header</header>
<main>Content</main>
<footer>Footer</footer>
</div>
Position
<!-- Position types -->
<div class="static">Default</div>
<div class="relative">Relative</div>
<div class="absolute">Absolute</div>
<div class="fixed">Fixed</div>
<div class="sticky">Sticky</div>
<!-- Positioning with inset -->
<div class="absolute top-0 left-0">Top-left</div>
<div class="absolute top-0 right-0">Top-right</div>
<div class="absolute bottom-0 left-0">Bottom-left</div>
<div class="absolute inset-0">All sides 0</div>
<div class="absolute inset-x-0">Left and right 0</div>
<div class="absolute inset-y-0">Top and bottom 0</div>
<!-- Sticky header -->
<header class="sticky top-0 bg-white z-10">
Sticky navigation
</header>
Float and Clear
<div class="float-left">Float left</div>
<div class="float-right">Float right</div>
<div class="clear-both">Clear floats</div>
Spacing
Padding
<!-- All sides -->
<div class="p-4">Padding 1rem (16px)</div>
<div class="p-0">No padding</div>
<div class="p-px">1px padding</div>
<!-- Horizontal/Vertical -->
<div class="px-4">Horizontal padding</div>
<div class="py-2">Vertical padding</div>
<!-- Individual sides -->
<div class="pt-4">Padding top</div>
<div class="pr-4">Padding right</div>
<div class="pb-4">Padding bottom</div>
<div class="pl-4">Padding left</div>
<!-- Spacing scale: 0, 0.5, 1, 1.5, 2, 2.5, 3, 3.5, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 64, 72, 80, 96 -->
Margin
<!-- All sides -->
<div class="m-4">Margin 1rem</div>
<div class="m-auto">Auto margin (for centering)</div>
<div class="-m-4">Negative margin</div>
<!-- Horizontal/Vertical -->
<div class="mx-auto">Center horizontally</div>
<div class="my-4">Vertical margin</div>
<!-- Individual sides -->
<div class="mt-4">Margin top</div>
<div class="mr-4">Margin right</div>
<div class="mb-4">Margin bottom</div>
<div class="ml-4">Margin left</div>
Space Between
<!-- Space between children (flex/grid) -->
<div class="flex space-x-4">
<div>Item 1</div>
<div>Item 2</div>
<div>Item 3</div>
</div>
<div class="flex flex-col space-y-4">
<div>Item 1</div>
<div>Item 2</div>
</div>
Sizing
Width
<!-- Fixed widths -->
<div class="w-32">Width 8rem (128px)</div>
<div class="w-64">Width 16rem (256px)</div>
<!-- Fractional widths -->
<div class="w-1/2">50% width</div>
<div class="w-1/3">33.333% width</div>
<div class="w-2/3">66.666% width</div>
<div class="w-1/4">25% width</div>
<div class="w-3/4">75% width</div>
<!-- Full widths -->
<div class="w-full">100% width</div>
<div class="w-screen">100vw width</div>
<!-- Min/Max width -->
<div class="min-w-0">Min-width 0</div>
<div class="min-w-full">Min-width 100%</div>
<div class="max-w-sm">Max-width 24rem</div>
<div class="max-w-md">Max-width 28rem</div>
<div class="max-w-lg">Max-width 32rem</div>
<div class="max-w-xl">Max-width 36rem</div>
<div class="max-w-2xl">Max-width 42rem</div>
<div class="max-w-full">Max-width 100%</div>
<div class="max-w-prose">Max-width 65ch (for reading)</div>
<!-- Arbitrary values -->
<div class="w-[420px]">Exact 420px</div>
Height
<!-- Fixed heights -->
<div class="h-32">Height 8rem</div>
<div class="h-64">Height 16rem</div>
<!-- Full heights -->
<div class="h-full">100% height</div>
<div class="h-screen">100vh height</div>
<!-- Min/Max height -->
<div class="min-h-screen">Min-height 100vh</div>
<div class="max-h-96">Max-height 24rem</div>
Typography
Font Family
<!-- Default font stacks -->
<p class="font-sans">Sans-serif font</p>
<p class="font-serif">Serif font</p>
<p class="font-mono">Monospace font</p>
<!-- Custom fonts (defined in config) -->
<p class="font-display">Display font</p>
Font Size
<p class="text-xs">Extra small (0.75rem)</p>
<p class="text-sm">Small (0.875rem)</p>
<p class="text-base">Base (1rem)</p>
<p class="text-lg">Large (1.125rem)</p>
<p class="text-xl">Extra large (1.25rem)</p>
<p class="text-2xl">2x large (1.5rem)</p>
<p class="text-3xl">3x large (1.875rem)</p>
<p class="text-4xl">4x large (2.25rem)</p>
<p class="text-5xl">5x large (3rem)</p>
<p class="text-6xl">6x large (3.75rem)</p>
<p class="text-7xl">7x large (4.5rem)</p>
<p class="text-8xl">8x large (6rem)</p>
<p class="text-9xl">9x large (8rem)</p>
Font Weight
<p class="font-thin">Thin (100)</p>
<p class="font-extralight">Extra light (200)</p>
<p class="font-light">Light (300)</p>
<p class="font-normal">Normal (400)</p>
<p class="font-medium">Medium (500)</p>
<p class="font-semibold">Semibold (600)</p>
<p class="font-bold">Bold (700)</p>
<p class="font-extrabold">Extra bold (800)</p>
<p class="font-black">Black (900)</p>
Text Alignment and Styling
<!-- Alignment -->
<p class="text-left">Left aligned</p>
<p class="text-center">Center aligned</p>
<p class="text-right">Right aligned</p>
<p class="text-justify">Justified</p>
<!-- Decoration -->
<p class="underline">Underlined</p>
<p class="line-through">Strikethrough</p>
<p class="no-underline">No underline</p>
<!-- Transform -->
<p class="uppercase">UPPERCASE</p>
<p class="lowercase">lowercase</p>
<p class="capitalize">Capitalize Each Word</p>
<p class="normal-case">Normal case</p>
<!-- Style -->
<p class="italic">Italic</p>
<p class="not-italic">Not italic</p>
Line Height and Letter Spacing
<!-- Line height -->
<p class="leading-none">Line height 1</p>
<p class="leading-tight">Line height 1.25</p>
<p class="leading-normal">Line height 1.5</p>
<p class="leading-loose">Line height 2</p>
<!-- Letter spacing -->
<p class="tracking-tighter">Very tight</p>
<p class="tracking-tight">Tight</p>
<p class="tracking-normal">Normal</p>
<p class="tracking-wide">Wide</p>
<p class="tracking-wider">Wider</p>
<p class="tracking-widest">Widest</p>
Text Overflow
<!-- Truncate with ellipsis -->
<p class="truncate">
This text will be truncated with ellipsis if it's too long
</p>
<!-- Overflow behavior -->
<p class="text-ellipsis">Ellipsis (formerly overflow-ellipsis)</p>
<p class="text-clip">Clip (formerly overflow-clip)</p>
<!-- Whitespace -->
<p class="whitespace-normal">Normal</p>
<p class="whitespace-nowrap">No wrap</p>
<p class="whitespace-pre">Preserve whitespace</p>
<p class="whitespace-pre-wrap">Preserve and wrap</p>
Colors
Background Colors
<!-- Gray scale -->
<div class="bg-white">White</div>
<div class="bg-gray-50">Gray 50</div>
<div class="bg-gray-100">Gray 100</div>
<div class="bg-gray-500">Gray 500</div>
<div class="bg-gray-900">Gray 900</div>
<div class="bg-black">Black</div>
<!-- Color palette (50-950 for each color) -->
<div class="bg-red-500">Red</div>
<div class="bg-orange-500">Orange</div>
<div class="bg-amber-500">Amber</div>
<div class="bg-yellow-500">Yellow</div>
<div class="bg-lime-500">Lime</div>
<div class="bg-green-500">Green</div>
<div class="bg-emerald-500">Emerald</div>
<div class="bg-teal-500">Teal</div>
<div class="bg-cyan-500">Cyan</div>
<div class="bg-sky-500">Sky</div>
<div class="bg-blue-500">Blue</div>
<div class="bg-indigo-500">Indigo</div>
<div class="bg-violet-500">Violet</div>
<div class="bg-purple-500">Purple</div>
<div class="bg-fuchsia-500">Fuchsia</div>
<div class="bg-pink-500">Pink</div>
<div class="bg-rose-500">Rose</div>
<!-- With opacity -->
<div class="bg-blue-500/50">50% opacity</div>
<div class="bg-blue-500/75">75% opacity</div>
Text Colors
<p class="text-gray-900">Dark gray text</p>
<p class="text-blue-600">Blue text</p>
<p class="text-red-500">Red text</p>
<!-- With opacity -->
<p class="text-gray-900/50">Semi-transparent text</p>
Border Colors
<div class="border border-gray-300">Gray border</div>
<div class="border-2 border-blue-500">Blue border</div>
Borders
<!-- Border width -->
<div class="border">1px border</div>
<div class="border-0">No border</div>
<div class="border-2">2px border</div>
<div class="border-4">4px border</div>
<div class="border-8">8px border</div>
<!-- Individual sides -->
<div class="border-t">Top border</div>
<div class="border-r">Right border</div>
<div class="border-b">Bottom border</div>
<div class="border-l">Left border</div>
<!-- Border style -->
<div class="border border-solid">Solid</div>
<div class="border border-dashed">Dashed</div>
<div class="border border-dotted">Dotted</div>
<div class="border border-double">Double</div>
<!-- Border radius -->
<div class="rounded-none">No radius</div>
<div class="rounded-sm">Small radius</div>
<div class="rounded">Default radius (0.25rem)</div>
<div class="rounded-md">Medium radius</div>
<div class="rounded-lg">Large radius</div>
<div class="rounded-xl">Extra large radius</div>
<div class="rounded-2xl">2x large radius</div>
<div class="rounded-3xl">3x large radius</div>
<div class="rounded-full">Fully rounded (circle/pill)</div>
<!-- Individual corners -->
<div class="rounded-tl-lg">Top-left</div>
<div class="rounded-tr-lg">Top-right</div>
<div class="rounded-br-lg">Bottom-right</div>
<div class="rounded-bl-lg">Bottom-left</div>
<!-- Divide (borders between children) -->
<div class="divide-y divide-gray-200">
<div class="py-2">Item 1</div>
<div class="py-2">Item 2</div>
<div class="py-2">Item 3</div>
</div>
Effects and Filters
Box Shadow
<div class="shadow-sm">Small shadow</div>
<div class="shadow">Default shadow</div>
<div class="shadow-md">Medium shadow</div>
<div class="shadow-lg">Large shadow</div>
<div class="shadow-xl">Extra large shadow</div>
<div class="shadow-2xl">2x large shadow</div>
<div class="shadow-inner">Inner shadow</div>
<div class="shadow-none">No shadow</div>
<!-- Colored shadows -->
<div class="shadow-lg shadow-blue-500/50">Blue shadow</div>
Opacity
<div class="opacity-0">Invisible</div>
<div class="opacity-25">25% opacity</div>
<div class="opacity-50">50% opacity</div>
<div class="opacity-75">75% opacity</div>
<div class="opacity-100">Fully opaque</div>
Blur
<div class="blur-none">No blur</div>
<div class="blur-sm">Small blur</div>
<div class="blur">Default blur</div>
<div class="blur-lg">Large blur</div>
<div class="blur-xl">Extra large blur</div>
<!-- Backdrop blur (for overlays) -->
<div class="backdrop-blur-sm">Backdrop blur</div>
Other Filters
<!-- Brightness -->
<img class="brightness-50" src="image.jpg">
<img class="brightness-125" src="image.jpg">
<!-- Contrast -->
<img class="contrast-50" src="image.jpg">
<img class="contrast-150" src="image.jpg">
<!-- Grayscale -->
<img class="grayscale" src="image.jpg">
<!-- Sepia -->
<img class="sepia" src="image.jpg">
Transitions and Animations
<!-- Transition property -->
<button class="transition">All properties</button>
<button class="transition-colors">Colors only</button>
<button class="transition-opacity">Opacity only</button>
<button class="transition-transform">Transform only</button>
<!-- Duration -->
<button class="transition duration-150">150ms (default)</button>
<button class="transition duration-300">300ms</button>
<button class="transition duration-500">500ms</button>
<button class="transition duration-1000">1s</button>
<!-- Timing function -->
<button class="transition ease-linear">Linear</button>
<button class="transition ease-in">Ease in</button>
<button class="transition ease-out">Ease out</button>
<button class="transition ease-in-out">Ease in-out</button>
<!-- Complete transition example -->
<button class="bg-blue-500 hover:bg-blue-700 transition-colors duration-300">
Smooth color transition
</button>
<!-- Animations -->
<div class="animate-spin">Spinning</div>
<div class="animate-ping">Pinging</div>
<div class="animate-pulse">Pulsing</div>
<div class="animate-bounce">Bouncing</div>
Transforms
<!-- Scale -->
<img class="scale-50 hover:scale-100"> <!-- 50% to 100% on hover -->
<img class="scale-100 hover:scale-110"> <!-- Zoom in on hover -->
<img class="scale-x-75"> <!-- Scale X only -->
<!-- Rotate -->
<img class="rotate-0 hover:rotate-45"> <!-- 0 to 45 degrees -->
<img class="rotate-90"> <!-- 90 degrees -->
<img class="rotate-180"> <!-- 180 degrees -->
<img class="-rotate-45"> <!-- -45 degrees -->
<!-- Translate -->
<div class="translate-x-4">Move right 1rem</div>
<div class="translate-y-4">Move down 1rem</div>
<div class="-translate-x-1/2">Move left 50%</div>
<!-- Skew -->
<div class="skew-x-12">Skew X</div>
<div class="skew-y-6">Skew Y</div>
<!-- Transform origin -->
<div class="origin-center">Center origin (default)</div>
<div class="origin-top-left">Top-left origin</div>
<!-- Combined transforms with transition -->
<button class="transition-transform duration-300 hover:scale-110 hover:rotate-3">
Hover for effect
</button>
Responsive Design
Tailwind uses a mobile-first breakpoint system. Unprefixed utilities apply to all screen sizes, while prefixed utilities apply at the specified breakpoint and above.
Breakpoint System
// Default breakpoints
sm: '640px' // Small devices (landscape phones)
md: '768px' // Medium devices (tablets)
lg: '1024px' // Large devices (desktops)
xl: '1280px' // Extra large devices (large desktops)
2xl: '1536px' // 2x extra large devices
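Mobile-first means a prefixed utility applies from its breakpoint upward, so the prefix in effect at a given viewport width is the largest breakpoint at or below that width. A sketch of that lookup (the function name is illustrative):

```javascript
// Default breakpoints, smallest to largest.
const breakpoints = [
  ['sm', 640], ['md', 768], ['lg', 1024], ['xl', 1280], ['2xl', 1536],
];

// Return the widest breakpoint whose min-width the viewport meets,
// or 'base' when only unprefixed utilities apply.
const activeBreakpoint = (width) => {
  let active = 'base';
  for (const [name, min] of breakpoints) {
    if (width >= min) active = name;
  }
  return active;
};

console.log(activeBreakpoint(500));  // → base
console.log(activeBreakpoint(800));  // → md
console.log(activeBreakpoint(1600)); // → 2xl
```

So for w-full md:w-1/2 at an 800px viewport, md is active and the element is half width.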
Responsive Utilities
<!-- Mobile: full width, Desktop: half width -->
<div class="w-full lg:w-1/2">
Responsive width
</div>
<!-- Hide on mobile, show on desktop -->
<div class="hidden lg:block">
Desktop only content
</div>
<!-- Responsive padding -->
<div class="p-4 md:p-6 lg:p-8">
Increasing padding
</div>
<!-- Responsive grid -->
<div class="grid grid-cols-1 sm:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4 gap-4">
<div>Item 1</div>
<div>Item 2</div>
<div>Item 3</div>
<div>Item 4</div>
</div>
Responsive Layout Example
<!-- Mobile: stacked, Desktop: side-by-side -->
<div class="flex flex-col lg:flex-row gap-4">
<!-- Sidebar: full width mobile, 1/4 width desktop -->
<aside class="w-full lg:w-1/4 bg-gray-100 p-4">
Sidebar
</aside>
<!-- Main: full width mobile, 3/4 width desktop -->
<main class="w-full lg:w-3/4 p-4">
Main content
</main>
</div>
Responsive Typography
<h1 class="text-2xl sm:text-3xl md:text-4xl lg:text-5xl xl:text-6xl font-bold">
Responsive heading
</h1>
<p class="text-sm md:text-base lg:text-lg">
Responsive paragraph
</p>
Custom Breakpoints
// tailwind.config.js
module.exports = {
theme: {
screens: {
'sm': '640px',
'md': '768px',
'lg': '1024px',
'xl': '1280px',
'2xl': '1536px',
'3xl': '1920px', // Custom breakpoint
},
},
}
<div class="hidden 3xl:block">
Only on 1920px+ screens
</div>
Container Queries (Plugin)
npm install @tailwindcss/container-queries
// tailwind.config.js
module.exports = {
plugins: [
require('@tailwindcss/container-queries'),
],
}
<div class="@container">
<div class="@md:text-2xl @lg:text-4xl">
Size based on container, not viewport
</div>
</div>
State Variants
Tailwind includes variants for styling elements based on their state.
Hover, Focus, and Active
<!-- Hover -->
<button class="bg-blue-500 hover:bg-blue-700">
Hover me
</button>
<!-- Focus -->
<input class="border border-gray-300 focus:border-blue-500 focus:ring-2 focus:ring-blue-200">
<!-- Active (being clicked) -->
<button class="bg-blue-500 active:bg-blue-800">
Click me
</button>
<!-- Combined states -->
<button class="
bg-blue-500
hover:bg-blue-600
focus:ring-2
focus:ring-blue-300
active:bg-blue-700
transition-colors
">
Full interaction states
</button>
Focus Visible
<!-- Only show focus ring for keyboard navigation -->
<button class="focus:outline-none focus-visible:ring-2 focus-visible:ring-blue-500">
Keyboard accessible
</button>
Form States
<!-- Disabled -->
<button class="bg-blue-500 disabled:bg-gray-300 disabled:cursor-not-allowed" disabled>
Disabled button
</button>
<!-- Required -->
<input class="border required:border-red-500" required>
<!-- Valid/Invalid -->
<input class="border invalid:border-red-500 valid:border-green-500" type="email">
<!-- Placeholder -->
<input class="placeholder:italic placeholder:text-gray-400" placeholder="Email address">
Group Hover and Focus
Style child elements when hovering over parent:
<div class="group hover:bg-blue-50 p-4 cursor-pointer">
<h3 class="group-hover:text-blue-600">Heading</h3>
<p class="group-hover:text-gray-700">
Hover over the card to change colors
</p>
<button class="opacity-0 group-hover:opacity-100">
Hidden button appears on card hover
</button>
</div>
<!-- Group with custom name -->
<div class="group/card hover:bg-blue-50">
<div class="group/item">
<p class="group-hover/card:text-blue-600">Card hover</p>
<p class="group-hover/item:text-red-600">Item hover</p>
</div>
</div>
Peer Modifiers
Style an element based on sibling state:
<label>
<input type="checkbox" class="peer sr-only">
<div class="
w-11 h-6 bg-gray-200 rounded-full
peer-checked:bg-blue-600
peer-focus:ring-2 peer-focus:ring-blue-300
">
<!-- Toggle switch styled by peer checkbox -->
</div>
</label>
<!-- Floating label -->
<div class="relative">
<input
id="email"
class="peer w-full border-b-2 border-gray-300 focus:border-blue-500"
placeholder=" "
>
<label
for="email"
class="
absolute left-0 top-0
text-gray-500
peer-placeholder-shown:top-2
peer-focus:top-0
peer-focus:text-xs
peer-focus:text-blue-500
transition-all
"
>
Email
</label>
</div>
Child Selectors
<!-- First and last child -->
<ul>
<li class="first:font-bold">First (bold)</li>
<li>Middle</li>
<li class="last:font-bold">Last (bold)</li>
</ul>
<!-- Odd and even -->
<table>
<tr class="odd:bg-white even:bg-gray-50">
<td>Row 1</td>
</tr>
<tr class="odd:bg-white even:bg-gray-50">
<td>Row 2</td>
</tr>
</table>
Before and After Pseudo-elements
<!-- Before -->
<div class="
before:content-['→']
before:mr-2
before:text-blue-500
">
Content with arrow before
</div>
<!-- After -->
<a class="
after:content-['_↗']
after:text-xs
after:text-gray-400
">
External link
</a>
Dark Mode
Tailwind includes first-class dark mode support.
Configuration
// tailwind.config.js
module.exports = {
// Choose strategy
darkMode: 'class', // or 'media'
// ...
}
Two strategies:
- 'media': Uses the prefers-color-scheme media query (follows system preference)
- 'class': Requires a .dark class on <html> or <body> (manual toggle)
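With the class strategy, an explicit stored choice usually wins over the system setting. The decision the toggle scripts later in this section make can be isolated as a pure function (the name is illustrative):

```javascript
// Resolve the effective theme: an explicit stored preference wins,
// otherwise fall back to the system's prefers-color-scheme value.
const resolveTheme = (stored, systemPrefersDark) => {
  if (stored === 'dark' || stored === 'light') return stored;
  return systemPrefersDark ? 'dark' : 'light';
};

console.log(resolveTheme('dark', false));   // → dark
console.log(resolveTheme(undefined, true)); // → dark
console.log(resolveTheme('light', true));   // → light
```

The returned value then drives document.documentElement.classList.add('dark') or .remove('dark').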
Using Dark Mode (Class Strategy)
<!-- Light mode: white background, dark text -->
<!-- Dark mode: dark background, light text -->
<div class="bg-white dark:bg-gray-900 text-gray-900 dark:text-white">
Content adapts to dark mode
</div>
Dark Mode Examples
<!-- Card with dark mode -->
<div class="bg-white dark:bg-gray-800 rounded-lg shadow-lg p-6">
<h2 class="text-gray-900 dark:text-white text-2xl font-bold">
Heading
</h2>
<p class="text-gray-700 dark:text-gray-300">
Description text
</p>
<button class="
bg-blue-500 hover:bg-blue-600
dark:bg-blue-600 dark:hover:bg-blue-700
text-white
">
Button
</button>
</div>
<!-- Form input -->
<input class="
bg-white dark:bg-gray-700
border border-gray-300 dark:border-gray-600
text-gray-900 dark:text-white
focus:border-blue-500 dark:focus:border-blue-400
focus:ring-2 focus:ring-blue-200 dark:focus:ring-blue-800
">
<!-- Image with different versions -->
<img
class="block dark:hidden"
src="logo-light.png"
alt="Logo"
>
<img
class="hidden dark:block"
src="logo-dark.png"
alt="Logo"
>
Dark Mode Toggle Implementation
<!-- HTML -->
<button id="theme-toggle" class="p-2 rounded-lg bg-gray-200 dark:bg-gray-700">
<!-- Sun icon (show in dark mode) -->
<svg class="hidden dark:block w-6 h-6" fill="currentColor" viewBox="0 0 20 20">
<path d="M10 2a1 1 0 011 1v1a1 1 0 11-2 0V3a1 1 0 011-1zm4 8a4 4 0 11-8 0 4 4 0 018 0zm-.464 4.95l.707.707a1 1 0 001.414-1.414l-.707-.707a1 1 0 00-1.414 1.414zm2.12-10.607a1 1 0 010 1.414l-.706.707a1 1 0 11-1.414-1.414l.707-.707a1 1 0 011.414 0zM17 11a1 1 0 100-2h-1a1 1 0 100 2h1zm-7 4a1 1 0 011 1v1a1 1 0 11-2 0v-1a1 1 0 011-1zM5.05 6.464A1 1 0 106.465 5.05l-.708-.707a1 1 0 00-1.414 1.414l.707.707zm1.414 8.486l-.707.707a1 1 0 01-1.414-1.414l.707-.707a1 1 0 011.414 1.414zM4 11a1 1 0 100-2H3a1 1 0 000 2h1z"></path>
</svg>
<!-- Moon icon (show in light mode) -->
<svg class="block dark:hidden w-6 h-6" fill="currentColor" viewBox="0 0 20 20">
<path d="M17.293 13.293A8 8 0 016.707 2.707a8.001 8.001 0 1010.586 10.586z"></path>
</svg>
</button>
<script>
// JavaScript for toggle
const toggle = document.getElementById('theme-toggle');
const html = document.documentElement;
// Check localStorage or system preference
if (localStorage.theme === 'dark' ||
(!('theme' in localStorage) &&
window.matchMedia('(prefers-color-scheme: dark)').matches)) {
html.classList.add('dark');
} else {
html.classList.remove('dark');
}
toggle.addEventListener('click', () => {
if (html.classList.contains('dark')) {
html.classList.remove('dark');
localStorage.theme = 'light';
} else {
html.classList.add('dark');
localStorage.theme = 'dark';
}
});
</script>
React Dark Mode Toggle
import { useState, useEffect } from 'react';
function DarkModeToggle() {
const [darkMode, setDarkMode] = useState(false);
useEffect(() => {
// Check localStorage or system preference
const isDark = localStorage.theme === 'dark' ||
(!('theme' in localStorage) &&
window.matchMedia('(prefers-color-scheme: dark)').matches);
setDarkMode(isDark);
if (isDark) {
document.documentElement.classList.add('dark');
}
}, []);
const toggleDarkMode = () => {
setDarkMode(!darkMode);
if (!darkMode) {
document.documentElement.classList.add('dark');
localStorage.theme = 'dark';
} else {
document.documentElement.classList.remove('dark');
localStorage.theme = 'light';
}
};
return (
<button
onClick={toggleDarkMode}
className="p-2 rounded-lg bg-gray-200 dark:bg-gray-700"
>
{darkMode ? '☀️' : '🌙'}
</button>
);
}
Component Patterns
Building real-world components with Tailwind utilities.
Buttons
<!-- Primary button -->
<button class="
px-4 py-2
bg-blue-600 hover:bg-blue-700
active:bg-blue-800
text-white font-medium
rounded-lg
transition-colors
focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-offset-2
">
Primary Button
</button>
<!-- Secondary button -->
<button class="
px-4 py-2
bg-gray-200 hover:bg-gray-300
text-gray-900 font-medium
rounded-lg
transition-colors
">
Secondary Button
</button>
<!-- Outline button -->
<button class="
px-4 py-2
border-2 border-blue-600
text-blue-600 hover:bg-blue-50
font-medium rounded-lg
transition-colors
">
Outline Button
</button>
<!-- Ghost button -->
<button class="
px-4 py-2
text-blue-600 hover:bg-blue-50
font-medium rounded-lg
transition-colors
">
Ghost Button
</button>
<!-- Danger button -->
<button class="
px-4 py-2
bg-red-600 hover:bg-red-700
text-white font-medium
rounded-lg
">
Delete
</button>
<!-- Disabled button -->
<button
class="
px-4 py-2
bg-blue-600
text-white font-medium
rounded-lg
disabled:bg-gray-300 disabled:cursor-not-allowed
"
disabled
>
Disabled
</button>
<!-- Button sizes -->
<button class="px-2 py-1 text-sm bg-blue-600 text-white rounded">Small</button>
<button class="px-4 py-2 text-base bg-blue-600 text-white rounded-lg">Medium</button>
<button class="px-6 py-3 text-lg bg-blue-600 text-white rounded-lg">Large</button>
<button class="px-8 py-4 text-xl bg-blue-600 text-white rounded-xl">XL</button>
<!-- Icon button -->
<button class="p-2 bg-blue-600 text-white rounded-lg hover:bg-blue-700">
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 4v16m8-8H4"></path>
</svg>
</button>
<!-- Button with icon -->
<button class="
flex items-center gap-2
px-4 py-2
bg-blue-600 hover:bg-blue-700
text-white rounded-lg
">
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 4v16m8-8H4"></path>
</svg>
Add Item
</button>
<!-- Loading button -->
<button class="
flex items-center gap-2
px-4 py-2
bg-blue-600
text-white rounded-lg
cursor-wait
" disabled>
<svg class="animate-spin h-5 w-5" viewBox="0 0 24 24">
<circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4" fill="none"></circle>
<path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path>
</svg>
Loading...
</button>
<!-- Button group -->
<div class="inline-flex rounded-lg shadow-sm">
<button class="px-4 py-2 bg-white border border-gray-300 rounded-l-lg hover:bg-gray-50">
Left
</button>
<button class="px-4 py-2 bg-white border-t border-b border-gray-300 hover:bg-gray-50">
Middle
</button>
<button class="px-4 py-2 bg-white border border-gray-300 rounded-r-lg hover:bg-gray-50">
Right
</button>
</div>
Cards
<!-- Basic card -->
<div class="bg-white rounded-lg shadow-md p-6">
<h3 class="text-xl font-bold mb-2">Card Title</h3>
<p class="text-gray-700">
This is a simple card component with rounded corners and shadow.
</p>
</div>
<!-- Product card -->
<div class="group bg-white rounded-lg shadow-md overflow-hidden hover:shadow-xl transition-shadow">
<!-- Image -->
<div class="relative overflow-hidden">
<img
src="product.jpg"
alt="Product"
class="w-full h-48 object-cover group-hover:scale-110 transition-transform duration-300"
>
<!-- Badge -->
<span class="absolute top-2 right-2 bg-red-500 text-white text-xs font-bold px-2 py-1 rounded">
SALE
</span>
</div>
<!-- Content -->
<div class="p-4">
<h3 class="text-lg font-semibold mb-2 group-hover:text-blue-600 transition-colors">
Product Name
</h3>
<p class="text-gray-600 text-sm mb-4">
Product description goes here
</p>
<!-- Price and button -->
<div class="flex items-center justify-between">
<div>
<span class="text-gray-400 line-through text-sm">$99.00</span>
<span class="text-2xl font-bold text-gray-900 ml-2">$79.00</span>
</div>
<button class="px-4 py-2 bg-blue-600 text-white rounded-lg hover:bg-blue-700">
Add to Cart
</button>
</div>
</div>
</div>
<!-- Profile card -->
<div class="bg-white rounded-xl shadow-lg p-6 max-w-sm">
<!-- Avatar -->
<div class="flex items-center gap-4 mb-4">
<img
src="avatar.jpg"
alt="Profile"
class="w-16 h-16 rounded-full object-cover"
>
<div>
<h3 class="text-lg font-bold text-gray-900">John Doe</h3>
<p class="text-gray-500 text-sm">Software Engineer</p>
</div>
</div>
<!-- Bio -->
<p class="text-gray-700 mb-4">
Passionate about building great user experiences with modern web technologies.
</p>
<!-- Stats -->
<div class="flex gap-4 mb-4">
<div class="text-center">
<div class="text-2xl font-bold text-gray-900">1.2K</div>
<div class="text-gray-500 text-sm">Followers</div>
</div>
<div class="text-center">
<div class="text-2xl font-bold text-gray-900">456</div>
<div class="text-gray-500 text-sm">Following</div>
</div>
<div class="text-center">
<div class="text-2xl font-bold text-gray-900">89</div>
<div class="text-gray-500 text-sm">Posts</div>
</div>
</div>
<!-- Actions -->
<div class="flex gap-2">
<button class="flex-1 px-4 py-2 bg-blue-600 text-white rounded-lg hover:bg-blue-700">
Follow
</button>
<button class="px-4 py-2 border border-gray-300 rounded-lg hover:bg-gray-50">
Message
</button>
</div>
</div>
<!-- Stats card with icon -->
<div class="bg-white rounded-lg shadow-md p-6">
<div class="flex items-center justify-between mb-4">
<div>
<p class="text-gray-500 text-sm font-medium">Total Revenue</p>
<p class="text-3xl font-bold text-gray-900">$45,231</p>
</div>
<div class="p-3 bg-green-100 rounded-full">
<svg class="w-8 h-8 text-green-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 8c-1.657 0-3 .895-3 2s1.343 2 3 2 3 .895 3 2-1.343 2-3 2m0-8c1.11 0 2.08.402 2.599 1M12 8V7m0 1v8m0 0v1m0-1c-1.11 0-2.08-.402-2.599-1M21 12a9 9 0 11-18 0 9 9 0 0118 0z"></path>
</svg>
</div>
</div>
<div class="flex items-center gap-1 text-sm">
<span class="text-green-600 font-medium">↑ 12%</span>
<span class="text-gray-500">from last month</span>
</div>
</div>
Forms
<!-- Complete form -->
<form class="max-w-md mx-auto bg-white rounded-lg shadow-md p-6">
<h2 class="text-2xl font-bold mb-6">Sign Up</h2>
<!-- Text input -->
<div class="mb-4">
<label class="block text-gray-700 font-medium mb-2" for="name">
Full Name
</label>
<input
id="name"
type="text"
class="
w-full px-4 py-2
border border-gray-300 rounded-lg
focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent
placeholder:text-gray-400
"
placeholder="John Doe"
>
</div>
<!-- Email input with validation states -->
<div class="mb-4">
<label class="block text-gray-700 font-medium mb-2" for="email">
Email
</label>
<input
id="email"
type="email"
class="
peer w-full px-4 py-2
border rounded-lg
focus:outline-none focus:ring-2 focus:ring-blue-500
invalid:border-red-500 invalid:ring-red-500
valid:border-green-500
"
placeholder="john@example.com"
required
>
<p class="mt-1 text-sm text-red-600 hidden peer-invalid:block">
Please enter a valid email
</p>
</div>
<!-- Password input -->
<div class="mb-4">
<label class="block text-gray-700 font-medium mb-2" for="password">
Password
</label>
<input
id="password"
type="password"
class="
w-full px-4 py-2
border border-gray-300 rounded-lg
focus:outline-none focus:ring-2 focus:ring-blue-500
"
required
>
</div>
<!-- Select -->
<div class="mb-4">
<label class="block text-gray-700 font-medium mb-2" for="country">
Country
</label>
<select
id="country"
class="
w-full px-4 py-2
border border-gray-300 rounded-lg
focus:outline-none focus:ring-2 focus:ring-blue-500
bg-white
"
>
<option>United States</option>
<option>Canada</option>
<option>United Kingdom</option>
<option>Australia</option>
</select>
</div>
<!-- Textarea -->
<div class="mb-4">
<label class="block text-gray-700 font-medium mb-2" for="bio">
Bio
</label>
<textarea
id="bio"
rows="4"
class="
w-full px-4 py-2
border border-gray-300 rounded-lg
focus:outline-none focus:ring-2 focus:ring-blue-500
resize-none
"
placeholder="Tell us about yourself..."
></textarea>
</div>
<!-- Checkbox -->
<div class="mb-4">
<label class="flex items-center">
<input
type="checkbox"
class="
w-4 h-4
text-blue-600
border-gray-300 rounded
focus:ring-2 focus:ring-blue-500
"
>
<span class="ml-2 text-gray-700">I agree to the Terms and Conditions</span>
</label>
</div>
<!-- Radio buttons -->
<div class="mb-6">
<p class="text-gray-700 font-medium mb-2">Newsletter</p>
<label class="flex items-center mb-2">
<input
type="radio"
name="newsletter"
value="daily"
class="w-4 h-4 text-blue-600 focus:ring-2 focus:ring-blue-500"
>
<span class="ml-2 text-gray-700">Daily</span>
</label>
<label class="flex items-center mb-2">
<input
type="radio"
name="newsletter"
value="weekly"
class="w-4 h-4 text-blue-600 focus:ring-2 focus:ring-blue-500"
checked
>
<span class="ml-2 text-gray-700">Weekly</span>
</label>
<label class="flex items-center">
<input
type="radio"
name="newsletter"
value="never"
class="w-4 h-4 text-blue-600 focus:ring-2 focus:ring-blue-500"
>
<span class="ml-2 text-gray-700">Never</span>
</label>
</div>
<!-- Submit button -->
<button
type="submit"
class="
w-full px-4 py-2
bg-blue-600 hover:bg-blue-700
text-white font-medium rounded-lg
transition-colors
focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-offset-2
"
>
Create Account
</button>
</form>
<!-- File upload -->
<div class="max-w-md mx-auto">
<label class="
flex flex-col items-center justify-center
w-full h-32
border-2 border-gray-300 border-dashed rounded-lg
cursor-pointer
hover:bg-gray-50
transition-colors
">
<div class="flex flex-col items-center justify-center pt-5 pb-6">
<svg class="w-10 h-10 mb-3 text-gray-400" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M7 16a4 4 0 01-.88-7.903A5 5 0 1115.9 6L16 6a5 5 0 011 9.9M15 13l-3-3m0 0l-3 3m3-3v12"></path>
</svg>
<p class="mb-2 text-sm text-gray-500">
<span class="font-semibold">Click to upload</span> or drag and drop
</p>
<p class="text-xs text-gray-500">PNG, JPG or GIF (MAX. 800x400px)</p>
</div>
<input type="file" class="hidden">
</label>
</div>
Navigation
<!-- Desktop navbar -->
<nav class="bg-white shadow-lg">
<div class="container mx-auto px-4">
<div class="flex items-center justify-between h-16">
<!-- Logo -->
<div class="flex items-center">
<a href="/" class="text-xl font-bold text-gray-900">
Logo
</a>
</div>
<!-- Desktop menu -->
<div class="hidden md:flex items-center space-x-4">
<a href="#" class="text-gray-700 hover:text-blue-600 px-3 py-2 rounded-md font-medium">
Home
</a>
<a href="#" class="text-gray-700 hover:text-blue-600 px-3 py-2 rounded-md font-medium">
About
</a>
<a href="#" class="text-gray-700 hover:text-blue-600 px-3 py-2 rounded-md font-medium">
Services
</a>
<a href="#" class="text-gray-700 hover:text-blue-600 px-3 py-2 rounded-md font-medium">
Contact
</a>
<button class="px-4 py-2 bg-blue-600 text-white rounded-lg hover:bg-blue-700">
Sign In
</button>
</div>
<!-- Mobile menu button -->
<div class="md:hidden">
<button class="p-2 rounded-md text-gray-700 hover:bg-gray-100">
<svg class="w-6 h-6" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 6h16M4 12h16M4 18h16"></path>
</svg>
</button>
</div>
</div>
</div>
<!-- Mobile menu (hidden by default) -->
<div class="md:hidden hidden">
<div class="px-2 pt-2 pb-3 space-y-1">
<a href="#" class="block px-3 py-2 rounded-md text-gray-700 hover:bg-gray-100">Home</a>
<a href="#" class="block px-3 py-2 rounded-md text-gray-700 hover:bg-gray-100">About</a>
<a href="#" class="block px-3 py-2 rounded-md text-gray-700 hover:bg-gray-100">Services</a>
<a href="#" class="block px-3 py-2 rounded-md text-gray-700 hover:bg-gray-100">Contact</a>
</div>
</div>
</nav>
<!-- Sidebar navigation -->
<aside class="w-64 bg-gray-900 min-h-screen">
<div class="p-4">
<h2 class="text-white text-xl font-bold mb-6">Dashboard</h2>
<nav class="space-y-2">
<a href="#" class="flex items-center gap-3 px-4 py-2 bg-blue-600 text-white rounded-lg">
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M3 12l2-2m0 0l7-7 7 7M5 10v10a1 1 0 001 1h3m10-11l2 2m-2-2v10a1 1 0 01-1 1h-3m-6 0a1 1 0 001-1v-4a1 1 0 011-1h2a1 1 0 011 1v4a1 1 0 001 1m-6 0h6"></path>
</svg>
Dashboard
</a>
<a href="#" class="flex items-center gap-3 px-4 py-2 text-gray-300 hover:bg-gray-800 rounded-lg">
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M16 7a4 4 0 11-8 0 4 4 0 018 0zM12 14a7 7 0 00-7 7h14a7 7 0 00-7-7z"></path>
</svg>
Users
</a>
<a href="#" class="flex items-center gap-3 px-4 py-2 text-gray-300 hover:bg-gray-800 rounded-lg">
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M10.325 4.317c.426-1.756 2.924-1.756 3.35 0a1.724 1.724 0 002.573 1.066c1.543-.94 3.31.826 2.37 2.37a1.724 1.724 0 001.065 2.572c1.756.426 1.756 2.924 0 3.35a1.724 1.724 0 00-1.066 2.573c.94 1.543-.826 3.31-2.37 2.37a1.724 1.724 0 00-2.572 1.065c-.426 1.756-2.924 1.756-3.35 0a1.724 1.724 0 00-2.573-1.066c-1.543.94-3.31-.826-2.37-2.37a1.724 1.724 0 00-1.065-2.572c-1.756-.426-1.756-2.924 0-3.35a1.724 1.724 0 001.066-2.573c-.94-1.543.826-3.31 2.37-2.37.996.608 2.296.07 2.572-1.065z"></path>
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M15 12a3 3 0 11-6 0 3 3 0 016 0z"></path>
</svg>
Settings
</a>
</nav>
</div>
</aside>
<!-- Breadcrumbs -->
<nav class="flex" aria-label="Breadcrumb">
<ol class="inline-flex items-center space-x-1 md:space-x-3">
<li class="inline-flex items-center">
<a href="#" class="text-gray-700 hover:text-blue-600">
Home
</a>
</li>
<li>
<div class="flex items-center">
<svg class="w-6 h-6 text-gray-400" fill="currentColor" viewBox="0 0 20 20">
<path fill-rule="evenodd" d="M7.293 14.707a1 1 0 010-1.414L10.586 10 7.293 6.707a1 1 0 011.414-1.414l4 4a1 1 0 010 1.414l-4 4a1 1 0 01-1.414 0z" clip-rule="evenodd"></path>
</svg>
<a href="#" class="ml-1 text-gray-700 hover:text-blue-600">
Products
</a>
</div>
</li>
<li>
<div class="flex items-center">
<svg class="w-6 h-6 text-gray-400" fill="currentColor" viewBox="0 0 20 20">
<path fill-rule="evenodd" d="M7.293 14.707a1 1 0 010-1.414L10.586 10 7.293 6.707a1 1 0 011.414-1.414l4 4a1 1 0 010 1.414l-4 4a1 1 0 01-1.414 0z" clip-rule="evenodd"></path>
</svg>
<span class="ml-1 text-gray-500">Details</span>
</div>
</li>
</ol>
</nav>
<!-- Tabs -->
<div class="border-b border-gray-200">
<nav class="flex space-x-8">
<a href="#" class="border-b-2 border-blue-500 text-blue-600 py-4 px-1 font-medium">
Profile
</a>
<a href="#" class="border-b-2 border-transparent text-gray-500 hover:text-gray-700 hover:border-gray-300 py-4 px-1 font-medium">
Settings
</a>
<a href="#" class="border-b-2 border-transparent text-gray-500 hover:text-gray-700 hover:border-gray-300 py-4 px-1 font-medium">
Notifications
</a>
</nav>
</div>
Modals and Overlays
<!-- Modal -->
<div class="fixed inset-0 bg-gray-600 bg-opacity-50 flex items-center justify-center p-4 z-50">
<!-- Modal content -->
<div class="bg-white rounded-lg shadow-xl max-w-md w-full">
<!-- Header -->
<div class="flex items-center justify-between p-6 border-b">
<h3 class="text-xl font-semibold text-gray-900">
Modal Title
</h3>
<button class="text-gray-400 hover:text-gray-600">
<svg class="w-6 h-6" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M6 18L18 6M6 6l12 12"></path>
</svg>
</button>
</div>
<!-- Body -->
<div class="p-6">
<p class="text-gray-700">
This is the modal content. You can add any content here.
</p>
</div>
<!-- Footer -->
<div class="flex justify-end gap-3 p-6 border-t">
<button class="px-4 py-2 text-gray-700 border border-gray-300 rounded-lg hover:bg-gray-50">
Cancel
</button>
<button class="px-4 py-2 bg-blue-600 text-white rounded-lg hover:bg-blue-700">
Confirm
</button>
</div>
</div>
</div>
<!-- Dropdown menu -->
<div class="relative inline-block text-left">
<button class="flex items-center gap-2 px-4 py-2 bg-white border border-gray-300 rounded-lg hover:bg-gray-50">
Options
<svg class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 9l-7 7-7-7"></path>
</svg>
</button>
<!-- Dropdown panel -->
<div class="absolute right-0 mt-2 w-56 bg-white rounded-lg shadow-lg ring-1 ring-black ring-opacity-5 z-10">
<div class="py-1">
<a href="#" class="block px-4 py-2 text-gray-700 hover:bg-gray-100">
Edit
</a>
<a href="#" class="block px-4 py-2 text-gray-700 hover:bg-gray-100">
Duplicate
</a>
<hr class="my-1">
<a href="#" class="block px-4 py-2 text-red-600 hover:bg-gray-100">
Delete
</a>
</div>
</div>
</div>
<!-- Toast notification -->
<div class="fixed top-4 right-4 bg-white rounded-lg shadow-lg p-4 max-w-sm animate-slide-in">
<div class="flex items-start gap-3">
<!-- Success icon -->
<div class="flex-shrink-0">
<svg class="w-6 h-6 text-green-500" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z"></path>
</svg>
</div>
<!-- Content -->
<div class="flex-1">
<p class="font-medium text-gray-900">Success!</p>
<p class="text-sm text-gray-500">Your changes have been saved.</p>
</div>
<!-- Close button -->
<button class="flex-shrink-0 text-gray-400 hover:text-gray-600">
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M6 18L18 6M6 6l12 12"></path>
</svg>
</button>
</div>
</div>
Alerts and Badges
<!-- Alert variants -->
<div class="bg-blue-50 border border-blue-200 text-blue-800 px-4 py-3 rounded-lg" role="alert">
<strong class="font-bold">Info!</strong>
<span class="block sm:inline"> This is an informational message.</span>
</div>
<div class="bg-green-50 border border-green-200 text-green-800 px-4 py-3 rounded-lg" role="alert">
<strong class="font-bold">Success!</strong>
<span class="block sm:inline"> Operation completed successfully.</span>
</div>
<div class="bg-yellow-50 border border-yellow-200 text-yellow-800 px-4 py-3 rounded-lg" role="alert">
<strong class="font-bold">Warning!</strong>
<span class="block sm:inline"> Please review before proceeding.</span>
</div>
<div class="bg-red-50 border border-red-200 text-red-800 px-4 py-3 rounded-lg" role="alert">
<strong class="font-bold">Error!</strong>
<span class="block sm:inline"> Something went wrong.</span>
</div>
<!-- Badges -->
<span class="px-2 py-1 text-xs font-semibold bg-gray-200 text-gray-800 rounded-full">
Default
</span>
<span class="px-2 py-1 text-xs font-semibold bg-blue-100 text-blue-800 rounded-full">
Primary
</span>
<span class="px-2 py-1 text-xs font-semibold bg-green-100 text-green-800 rounded-full">
Success
</span>
<span class="px-2 py-1 text-xs font-semibold bg-red-100 text-red-800 rounded-full">
Danger
</span>
<!-- Badge with dot -->
<span class="inline-flex items-center gap-1 px-2 py-1 text-xs font-semibold bg-green-100 text-green-800 rounded-full">
<span class="w-2 h-2 bg-green-500 rounded-full"></span>
Active
</span>
<!-- Loading skeleton -->
<div class="animate-pulse">
<div class="h-4 bg-gray-200 rounded w-3/4 mb-2"></div>
<div class="h-4 bg-gray-200 rounded w-1/2 mb-2"></div>
<div class="h-4 bg-gray-200 rounded w-5/6"></div>
</div>
<!-- Progress bar -->
<div class="w-full bg-gray-200 rounded-full h-2.5">
<div class="bg-blue-600 h-2.5 rounded-full" style="width: 45%"></div>
</div>
Layout Patterns
Dashboard Layout
<div class="min-h-screen bg-gray-100">
<!-- Sidebar -->
<aside class="fixed inset-y-0 left-0 w-64 bg-gray-900">
<!-- Sidebar content here -->
</aside>
<!-- Main content -->
<div class="ml-64">
<!-- Header -->
<header class="bg-white shadow-sm sticky top-0 z-10">
<div class="px-6 py-4">
<h1 class="text-2xl font-bold">Dashboard</h1>
</div>
</header>
<!-- Content -->
<main class="p-6">
<!-- Grid of cards/widgets -->
<div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-6">
<!-- Cards here -->
</div>
</main>
</div>
</div>
Landing Page Hero
<section class="relative bg-gradient-to-r from-blue-600 to-indigo-700 text-white">
<div class="container mx-auto px-4 py-20 md:py-32">
<div class="max-w-3xl mx-auto text-center">
<h1 class="text-4xl md:text-5xl lg:text-6xl font-bold mb-6">
Build Amazing Products
</h1>
<p class="text-xl md:text-2xl mb-8 text-blue-100">
The fastest way to create beautiful, responsive websites
</p>
<div class="flex flex-col sm:flex-row gap-4 justify-center">
<button class="px-8 py-3 bg-white text-blue-600 rounded-lg font-semibold hover:bg-gray-100">
Get Started
</button>
<button class="px-8 py-3 border-2 border-white rounded-lg font-semibold hover:bg-white hover:text-blue-600 transition-colors">
Learn More
</button>
</div>
</div>
</div>
</section>
Centering Techniques
<!-- Flexbox centering -->
<div class="flex items-center justify-center min-h-screen">
<div class="text-center">
Perfectly centered
</div>
</div>
<!-- Grid centering -->
<div class="grid place-items-center min-h-screen">
<div>
Centered with grid
</div>
</div>
<!-- Absolute centering -->
<div class="relative h-screen">
<div class="absolute top-1/2 left-1/2 -translate-x-1/2 -translate-y-1/2">
Centered with transform
</div>
</div>
Holy Grail Layout
<div class="min-h-screen flex flex-col">
<!-- Header -->
<header class="bg-gray-800 text-white p-4">
Header
</header>
<!-- Main content area -->
<div class="flex flex-1">
<!-- Left sidebar -->
<aside class="w-64 bg-gray-100 p-4">
Left Sidebar
</aside>
<!-- Main content -->
<main class="flex-1 p-4">
Main Content
</main>
<!-- Right sidebar -->
<aside class="w-64 bg-gray-100 p-4">
Right Sidebar
</aside>
</div>
<!-- Footer -->
<footer class="bg-gray-800 text-white p-4">
Footer
</footer>
</div>
Sticky Header/Footer
<div class="min-h-screen flex flex-col">
<!-- Sticky header -->
<header class="sticky top-0 bg-white shadow-md p-4 z-10">
Sticky Header
</header>
<!-- Main content (scrollable) -->
<main class="flex-1 p-4">
<!-- Long content here -->
</main>
<!-- Sticky footer -->
<footer class="sticky bottom-0 bg-gray-800 text-white p-4">
Sticky Footer
</footer>
</div>
Customization
Extending Colors
// tailwind.config.js
module.exports = {
theme: {
extend: {
colors: {
// Brand colors
brand: {
50: '#eff6ff',
100: '#dbeafe',
500: '#3b82f6',
900: '#1e3a8a',
},
// Single color
'accent': '#ff6b6b',
},
},
},
}
<!-- Use custom colors -->
<div class="bg-brand-500 text-white">Brand color</div>
<div class="bg-accent text-white">Accent color</div>
Extending Spacing
module.exports = {
theme: {
extend: {
spacing: {
'128': '32rem',
'144': '36rem',
},
},
},
}
<div class="p-128">Extra large padding</div>
Custom Fonts
module.exports = {
theme: {
extend: {
fontFamily: {
sans: ['Inter', 'sans-serif'],
display: ['Lexend', 'sans-serif'],
body: ['Open Sans', 'sans-serif'],
},
},
},
}
/* In your CSS */
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
<h1 class="font-display">Display font</h1>
<p class="font-body">Body font</p>
Arbitrary Values
Use square brackets for one-off custom values:
<!-- Custom width -->
<div class="w-[347px]">Exact width</div>
<!-- Custom color -->
<div class="bg-[#1da1f2]">Twitter blue</div>
<!-- Custom grid -->
<div class="grid-cols-[200px_1fr_200px]">Custom grid</div>
<!-- Custom shadow -->
<div class="shadow-[0_35px_60px_-15px_rgba(0,0,0,0.3)]">Custom shadow</div>
Adding Custom Utilities
// tailwind.config.js
const plugin = require('tailwindcss/plugin')
module.exports = {
plugins: [
plugin(function({ addUtilities }) {
const newUtilities = {
'.text-shadow': {
textShadow: '2px 2px 4px rgba(0,0,0,0.1)',
},
'.text-shadow-lg': {
textShadow: '4px 4px 8px rgba(0,0,0,0.2)',
},
}
addUtilities(newUtilities)
})
],
}
<h1 class="text-shadow">Text with shadow</h1>
Plugin System
Official Plugins
@tailwindcss/forms
Provides better default styles for form elements.
npm install @tailwindcss/forms
// tailwind.config.js
module.exports = {
plugins: [
require('@tailwindcss/forms'),
],
}
<!-- Forms are automatically styled nicely -->
<input type="text" class="mt-1 block w-full">
<select class="mt-1 block w-full">
<option>Option 1</option>
</select>
@tailwindcss/typography
Adds a prose class for styling user-generated content such as rendered Markdown.
npm install @tailwindcss/typography
module.exports = {
plugins: [
require('@tailwindcss/typography'),
],
}
<article class="prose lg:prose-xl">
<!-- All HTML elements are beautifully styled -->
<h1>Heading</h1>
<p>Paragraph with nice defaults</p>
<ul>
<li>List item</li>
</ul>
</article>
<!-- Dark mode -->
<article class="prose dark:prose-invert">
Content
</article>
@tailwindcss/aspect-ratio
Maintains aspect ratios for elements. Note: Tailwind v3 ships native aspect-ratio utilities (aspect-video, aspect-square), so this plugin is mainly needed for older browsers without CSS aspect-ratio support.
npm install @tailwindcss/aspect-ratio
<div class="aspect-w-16 aspect-h-9">
<iframe src="video.mp4"></iframe>
</div>
@tailwindcss/container-queries
Enables container-based responsive design.
npm install @tailwindcss/container-queries
<div class="@container">
<div class="@lg:text-xl">
Responds to container size, not viewport
</div>
</div>
Creating Custom Plugins
// tailwind.config.js
const plugin = require('tailwindcss/plugin')
module.exports = {
plugins: [
// Simple utility plugin
plugin(function({ addUtilities }) {
addUtilities({
'.rotate-y-180': {
transform: 'rotateY(180deg)',
},
})
}),
// Plugin with options
plugin(function({ addComponents, theme }) {
addComponents({
'.btn': {
padding: theme('spacing.4'),
borderRadius: theme('borderRadius.lg'),
fontWeight: theme('fontWeight.semibold'),
'&:hover': {
opacity: 0.8,
},
},
'.btn-primary': {
backgroundColor: theme('colors.blue.500'),
color: theme('colors.white'),
},
})
}),
],
}
Framework Integration
React / Next.js
Next.js 13+ includes Tailwind by default with create-next-app:
npx create-next-app@latest my-app --tailwind
Manual setup:
npm install -D tailwindcss postcss autoprefixer
npx tailwindcss init -p
Example React component:
// components/Button.jsx
export default function Button({ children, variant = 'primary' }) {
const baseClasses = "px-4 py-2 rounded-lg font-medium transition-colors";
const variants = {
primary: "bg-blue-600 hover:bg-blue-700 text-white",
secondary: "bg-gray-200 hover:bg-gray-300 text-gray-900",
outline: "border-2 border-blue-600 text-blue-600 hover:bg-blue-50",
};
return (
<button className={`${baseClasses} ${variants[variant]}`}>
{children}
</button>
);
}
Using clsx for conditional classes:
import clsx from 'clsx';
function Button({ variant, size, children }) {
return (
<button
className={clsx(
'font-semibold rounded-lg transition-colors',
{
'bg-blue-600 text-white hover:bg-blue-700': variant === 'primary',
'bg-gray-200 text-gray-900 hover:bg-gray-300': variant === 'secondary',
'px-3 py-1.5 text-sm': size === 'sm',
'px-4 py-2 text-base': size === 'md',
'px-6 py-3 text-lg': size === 'lg',
}
)}
>
{children}
</button>
);
}
Vue / Nuxt
Nuxt 3:
npm install -D @nuxtjs/tailwindcss
// nuxt.config.ts
export default defineNuxtConfig({
modules: ['@nuxtjs/tailwindcss'],
})
Vue 3 component:
<template>
<button
:class="[
'px-4 py-2 rounded-lg font-medium transition-colors',
variantClasses
]"
>
<slot />
</button>
</template>
<script setup>
import { computed } from 'vue';
const props = defineProps({
variant: {
type: String,
default: 'primary'
}
});
const variantClasses = computed(() => {
const variants = {
primary: 'bg-blue-600 hover:bg-blue-700 text-white',
secondary: 'bg-gray-200 hover:bg-gray-300 text-gray-900',
};
return variants[props.variant];
});
</script>
Svelte / SvelteKit
npx svelte-add@latest tailwindcss
Svelte component:
<script>
export let variant = 'primary';
$: variantClasses = {
primary: 'bg-blue-600 hover:bg-blue-700 text-white',
secondary: 'bg-gray-200 hover:bg-gray-300 text-gray-900',
}[variant];
</script>
<button class="px-4 py-2 rounded-lg font-medium transition-colors {variantClasses}">
<slot />
</button>
Advanced Topics
@layer Directive
Organize custom styles into Tailwind's layers:
@tailwind base;
@tailwind components;
@tailwind utilities;
@layer base {
h1 {
@apply text-4xl font-bold;
}
a {
@apply text-blue-600 hover:underline;
}
}
@layer components {
.btn {
@apply px-4 py-2 bg-blue-600 text-white rounded-lg hover:bg-blue-700;
}
.card {
@apply bg-white rounded-lg shadow-md p-6;
}
}
@layer utilities {
.text-shadow {
text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.1);
}
}
@apply Directive
Extract repeated utilities into custom classes:
.btn-primary {
@apply px-4 py-2 bg-blue-600 text-white rounded-lg hover:bg-blue-700 transition-colors;
}
⚠️ Use sparingly: Only extract when you have true component repetition across multiple files.
Custom Variants
// tailwind.config.js
const plugin = require('tailwindcss/plugin')
module.exports = {
plugins: [
plugin(function({ addVariant }) {
// Custom variant for third child
addVariant('third', '&:nth-child(3)');
// Custom variant for optional elements
addVariant('optional', '&:optional');
// Custom variant for hocus (hover + focus)
addVariant('hocus', ['&:hover', '&:focus']);
})
],
}
<div class="third:bg-blue-500">Third child is blue</div>
<input class="optional:border-gray-300" />
<button class="hocus:bg-blue-700">Hover or focus</button>
Important Modifier
Force a utility to be !important:
<!-- Without important -->
<p class="text-red-500">Red text</p>
<!-- With important (overrides everything) -->
<p class="!text-red-500">Always red text</p>
Arbitrary Variants
Create one-off variants with square brackets:
<!-- Target specific data attribute -->
<div class="[&[data-state='active']]:bg-blue-500" data-state="active">
Blue when active
</div>
<!-- Target child elements -->
<ul class="[&>li]:mb-2">
<li>Item 1</li>
<li>Item 2</li>
</ul>
<!-- Complex selectors -->
<div class="[&:nth-child(3)]:text-red-500">
Third child is red
</div>
Performance Optimization
Content Configuration
Tell Tailwind exactly where to look for class names:
// tailwind.config.js
module.exports = {
content: [
'./src/**/*.{js,jsx,ts,tsx}',
'./public/index.html',
// Don't include:
// - node_modules (unless using Tailwind in a package)
// - build/dist folders
],
}
JIT (Just-In-Time) Mode
JIT is enabled by default in Tailwind 3+. It generates styles on-demand as you author your templates.
Benefits:
- Lightning fast build times
- All variants enabled by default
- Arbitrary values work everywhere
- Smaller CSS in development
- Better performance
Production Build
With the JIT engine (default in v3), Tailwind only generates the classes it finds in your content files, so the production CSS is already minimal — just minify it:
NODE_ENV=production npx tailwindcss -o output.css --minify
In build tools, set NODE_ENV=production:
// package.json
{
"scripts": {
"build": "NODE_ENV=production webpack build"
}
}
Bundle Size Tips
- Configure content paths accurately - Only classes found in your content files are generated
- Don't rely on PurgeCSS - Tailwind v3's JIT engine replaces it; unused styles are never generated
- Avoid safelist overuse - Only safelist truly dynamic classes
- Enable minification - Always in production builds
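To make the safelist tip concrete, here is a sketch of a config that safelists classes assembled at runtime (e.g. `bg-${color}-500`), which the content scanner cannot see. The class names and pattern below are hypothetical examples, not a recommendation to safelist these specific classes:

```javascript
// tailwind.config.js — minimal safelist sketch; safelist only what is
// truly dynamic, since everything listed here is always generated
module.exports = {
  content: ['./src/**/*.{js,jsx,ts,tsx}'],
  safelist: [
    'bg-red-500',                             // one known dynamic class
    { pattern: /^bg-(red|green|blue)-500$/ }, // or a narrow pattern
  ],
}
```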
Best Practices
- Use utility classes in HTML
  - Keeps styles close to usage
  - Easier to understand and modify
  - No context switching
- Extract components when needed
  - Repeated patterns across multiple files
  - True reusable components
  - Not just to reduce class count in one place
- Use consistent spacing scale
  - Stick to Tailwind's spacing scale (4, 8, 16, 24, 32...)
  - Use arbitrary values sparingly
  - Creates visual rhythm
- Mobile-first responsive design
  - Start with mobile layout
  - Add breakpoints for larger screens
  - md: for tablet, lg: for desktop
- Organize classes logically
  - Layout → Spacing → Sizing → Typography → Colors → Effects
  - Example: flex items-center px-4 py-2 text-lg font-bold bg-blue-500 rounded-lg shadow
- Use editor extensions
  - Tailwind CSS IntelliSense (VSCode)
  - Auto-complete and class sorting
  - Linting and validation
- Combine with component frameworks
  - Headless UI for accessible components
  - Radix UI primitives
  - Build design system on top
- Don't fight the framework
  - Use Tailwind's design tokens
  - Extend theme rather than arbitrary values
  - Embrace the constraints
- When NOT to use Tailwind
  - Simple static sites
  - Teams that prefer CSS-in-JS
  - Projects with strict CSS architecture requirements
  - When you need maximum control over generated CSS
- Performance considerations
  - Configure content paths correctly
  - Safelist only what's necessary
  - Use JIT mode (default in v3)
  - Minify in production
Accessibility
Focus States
<!-- Always include focus styles -->
<button class="
bg-blue-600
focus:outline-none
focus:ring-2
focus:ring-blue-500
focus:ring-offset-2
">
Accessible button
</button>
<!-- Focus-visible (keyboard only) -->
<a href="#" class="
focus:outline-none
focus-visible:ring-2
focus-visible:ring-blue-500
">
Link
</a>
Screen Reader Utilities
<!-- Screen reader only text -->
<button class="p-2">
<svg class="w-6 h-6" fill="currentColor">
<!-- Icon -->
</svg>
<span class="sr-only">Close menu</span>
</button>
<!-- Hide from screen readers -->
<div aria-hidden="true" class="text-gray-400">
Decorative element
</div>
Color Contrast
<!-- Good contrast -->
<div class="bg-gray-900 text-white">High contrast</div>
<!-- Ensure sufficient contrast -->
<p class="text-gray-600"><!-- Check contrast ratio --></p>
<!-- Use Tailwind's color scales appropriately -->
<!-- On white bg: gray-700, gray-800, gray-900 are safe -->
<!-- On dark bg: gray-100, gray-200, gray-300 are safe -->
Keyboard Navigation
<!-- Ensure tab order makes sense -->
<nav>
<a href="#" class="focus:ring-2" tabindex="0">Link 1</a>
<a href="#" class="focus:ring-2" tabindex="0">Link 2</a>
</nav>
<!-- Skip link for keyboard users -->
<a href="#main-content" class="sr-only focus:not-sr-only focus:absolute focus:top-0">
Skip to main content
</a>
Migration and Comparison
Migrating from Bootstrap
Bootstrap approach:
<div class="container">
<div class="row">
<div class="col-md-6">Column 1</div>
<div class="col-md-6">Column 2</div>
</div>
</div>
Tailwind equivalent:
<div class="container mx-auto px-4">
<div class="grid grid-cols-1 md:grid-cols-2 gap-4">
<div>Column 1</div>
<div>Column 2</div>
</div>
</div>
Tailwind vs CSS-in-JS
| Aspect | Tailwind | CSS-in-JS (styled-components) |
|---|---|---|
| Syntax | HTML classes | JavaScript objects/strings |
| Runtime | No runtime | Runtime overhead |
| File size | Small (purged) | Depends on usage |
| Theming | Config file | Theme provider |
| Learning curve | Learn utilities | Learn library API |
| Type safety | Via LSP | Native TypeScript |
Pros and Cons
Pros:
- ✅ Rapid development
- ✅ Consistent design system
- ✅ Small production bundle
- ✅ No naming fatigue
- ✅ Responsive by default
- ✅ Great developer experience
- ✅ Highly customizable
Cons:
- ❌ HTML can look cluttered
- ❌ Learning curve for utilities
- ❌ Team alignment needed
- ❌ Harder to enforce design patterns
- ❌ Some prefer separation of concerns
Tooling and Ecosystem
Editor Extensions
VS Code:
- Tailwind CSS IntelliSense: Auto-complete, syntax highlighting, linting
- Tailwind Fold: Fold long class strings
- Headwind: Auto-sort Tailwind classes
Settings for VSCode:
{
"tailwindCSS.experimental.classRegex": [
["class:\\s*?[\"'`]([^\"'`]*).*?[\"'`]", "[\"'`]([^\"'`]*).*?[\"'`]"],
],
"editor.quickSuggestions": {
"strings": true
}
}
Prettier Plugin
Auto-sort classes in consistent order:
npm install -D prettier prettier-plugin-tailwindcss
// .prettierrc
{
"plugins": ["prettier-plugin-tailwindcss"]
}
Headless UI
Unstyled, accessible UI components:
npm install @headlessui/react
import { Dialog } from '@headlessui/react'
function MyDialog({ isOpen, onClose }) {
return (
<Dialog open={isOpen} onClose={onClose} className="relative z-50">
<div className="fixed inset-0 bg-black/30" aria-hidden="true" />
<div className="fixed inset-0 flex items-center justify-center p-4">
<Dialog.Panel className="bg-white rounded-lg p-6 max-w-sm">
<Dialog.Title className="text-lg font-medium">Title</Dialog.Title>
<Dialog.Description>Description</Dialog.Description>
<button onClick={onClose} className="mt-4 px-4 py-2 bg-blue-600 text-white rounded">
Close
</button>
</Dialog.Panel>
</div>
</Dialog>
)
}
Component Libraries
Free:
- daisyUI: Component library built on Tailwind
- Flowbite: Open-source component library
- Preline: Free Tailwind components
- Mamba UI: Free Tailwind components
Commercial:
- Tailwind UI: Official component library (paid)
- Meraki UI: Premium components
Resources
Official Documentation
- Tailwind CSS Docs: https://tailwindcss.com/docs
- Tailwind Play (playground): https://play.tailwindcss.com/
- GitHub: https://github.com/tailwindlabs/tailwindcss
Learning Resources
- Tailwind CSS Tutorial (official): https://tailwindcss.com/docs/installation
- Scrimba Tailwind Course: Interactive lessons
- Tailwind from A to Z (YouTube): Adam Wathan
- Tailwind CSS From Scratch (Traversy Media)
Component Libraries
- Headless UI: https://headlessui.com/
- daisyUI: https://daisyui.com/
- Flowbite: https://flowbite.com/
- Tailwind UI: https://tailwindui.com/ (commercial)
Tools
- Tailwind CSS IntelliSense: VS Code extension
- Prettier Plugin: Auto-sort classes
- Tailwind Cheat Sheet: https://nerdcave.com/tailwind-cheat-sheet
- Tailwind Color Shades Generator: Generate custom color palettes
Icons
- Heroicons: https://heroicons.com/ (by Tailwind makers)
- Tabler Icons: https://tabler-icons.io/
- Lucide Icons: https://lucide.dev/
Community
- Discord: Official Tailwind Discord server
- Twitter: @tailwindcss
- GitHub Discussions: Community Q&A
- Reddit: r/tailwindcss
Last Updated: January 2025
Express.js
Express.js is a minimal and flexible Node.js web application framework that provides a robust set of features for building web and mobile applications. It's the de facto standard server framework for Node.js and is widely used for building RESTful APIs and web applications.
Table of Contents
- Introduction
- Installation and Setup
- Basic Application
- Routing
- Middleware
- Request and Response
- Error Handling
- Template Engines
- Static Files
- Database Integration
- Authentication
- RESTful API
- File Uploads
- Security Best Practices
- Testing
- Production Deployment
Introduction
Key Features:
- Minimal and unopinionated framework
- Robust routing system
- Focus on high performance
- Super-high test coverage
- HTTP helpers (redirection, caching, etc.)
- View system with 14+ template engines
- Content negotiation
- Executable for generating applications quickly
Use Cases:
- RESTful APIs
- Web applications
- Microservices
- Real-time applications (with Socket.io)
- Server-side rendering
- Proxy servers
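The content negotiation listed above is exposed by Express as res.format, which picks a response representation from the request's Accept header. The idea can be illustrated with a standalone sketch — negotiate() is a hypothetical helper written for this note, not an Express API:

```javascript
// Sketch of the idea behind Express's content negotiation (res.format):
// choose a response representation based on the request's Accept header.
function negotiate(acceptHeader, handlers) {
  const accepted = (acceptHeader || '*/*')
    .split(',')
    .map((part) => part.trim().split(';')[0]); // drop q-values for simplicity
  for (const type of accepted) {
    if (handlers[type]) return handlers[type]();
    if (type === '*/*') return handlers[Object.keys(handlers)[0]](); // first registered wins
  }
  return { status: 406, body: 'Not Acceptable' };
}

const handlers = {
  'application/json': () => ({ status: 200, body: JSON.stringify({ message: 'hi' }) }),
  'text/html': () => ({ status: 200, body: '<p>hi</p>' }),
};

console.log(negotiate('text/html;q=0.9', handlers).body);   // <p>hi</p>
console.log(negotiate('application/xml', handlers).status); // 406
```

In a real route you would call res.format({ json() {...}, html() {...} }) and let Express do this matching (including proper q-value ordering) for you.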
Installation and Setup
Create New Project
# Create project directory
mkdir my-express-app
cd my-express-app
# Initialize npm project
npm init -y
# Install Express
npm install express
# Install development dependencies
npm install --save-dev nodemon typescript ts-node @types/node @types/express
TypeScript Setup
# Initialize TypeScript
npx tsc --init
tsconfig.json:
{
"compilerOptions": {
"target": "ES2020",
"module": "commonjs",
"lib": ["ES2020"],
"outDir": "./dist",
"rootDir": "./src",
"strict": true,
"esModuleInterop": true,
"skipLibCheck": true,
"forceConsistentCasingInFileNames": true,
"resolveJsonModule": true
},
"include": ["src/**/*"],
"exclude": ["node_modules"]
}
package.json scripts:
{
"scripts": {
"build": "tsc",
"start": "node dist/index.js",
"dev": "nodemon --exec ts-node src/index.ts",
"watch": "tsc --watch"
}
}
Basic Application
Minimal Express App
const express = require('express');
const app = express();
const PORT = 3000;
app.get('/', (req, res) => {
res.send('Hello World!');
});
app.listen(PORT, () => {
console.log(`Server running on http://localhost:${PORT}`);
});
TypeScript Version
import express, { Express, Request, Response } from 'express';
const app: Express = express();
const PORT = process.env.PORT || 3000;
app.get('/', (req: Request, res: Response) => {
res.send('Hello World!');
});
app.listen(PORT, () => {
console.log(`Server running on http://localhost:${PORT}`);
});
Application Structure
my-express-app/
├── src/
│ ├── index.ts # Entry point
│ ├── config/
│ │ ├── database.ts # Database configuration
│ │ └── environment.ts # Environment variables
│ ├── controllers/ # Route controllers
│ │ └── userController.ts
│ ├── middleware/ # Custom middleware
│ │ ├── auth.ts
│ │ └── errorHandler.ts
│ ├── models/ # Data models
│ │ └── User.ts
│ ├── routes/ # Route definitions
│ │ └── userRoutes.ts
│ ├── services/ # Business logic
│ │ └── userService.ts
│ └── utils/ # Utility functions
│ └── validators.ts
├── tests/ # Test files
├── dist/ # Compiled JavaScript
├── node_modules/
├── package.json
├── tsconfig.json
└── .env
Routing
Basic Routes
import express from 'express';
const app = express();
// GET request
app.get('/users', (req, res) => {
res.json({ message: 'Get all users' });
});
// POST request
app.post('/users', (req, res) => {
res.json({ message: 'Create user' });
});
// PUT request
app.put('/users/:id', (req, res) => {
res.json({ message: `Update user ${req.params.id}` });
});
// DELETE request
app.delete('/users/:id', (req, res) => {
res.json({ message: `Delete user ${req.params.id}` });
});
// PATCH request
app.patch('/users/:id', (req, res) => {
res.json({ message: `Partially update user ${req.params.id}` });
});
Route Parameters
// Single parameter
app.get('/users/:id', (req, res) => {
const userId = req.params.id;
res.json({ userId });
});
// Multiple parameters
app.get('/users/:userId/posts/:postId', (req, res) => {
const { userId, postId } = req.params;
res.json({ userId, postId });
});
// Optional parameter with a regex constraint (Express 4 syntax; Express 5 changed route pattern support)
app.get('/users/:id(\\d+)?', (req, res) => {
res.json({ id: req.params.id || 'all' });
});
Query Parameters
// GET /search?q=express&limit=10
app.get('/search', (req, res) => {
const { q, limit = 10 } = req.query;
res.json({ query: q, limit });
});
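Note that query values always arrive as strings — in the handler above, `limit` is the string "10" unless you coerce it. A small helper to normalize and clamp pagination params (the function and option names here are illustrative, not Express API):

```javascript
// Coerce and clamp pagination query params, which Express exposes as strings.
function parsePagination(query, { defLimit = 10, maxLimit = 100 } = {}) {
  const page = Math.max(1, parseInt(query.page, 10) || 1);
  const limit = Math.min(maxLimit, Math.max(1, parseInt(query.limit, 10) || defLimit));
  return { page, limit };
}

// e.g. GET /search?page=2&limit=500  →  { page: 2, limit: 100 }
```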
Route Handlers
// Single callback
app.get('/example1', (req, res) => {
res.send('Single callback');
});
// Multiple callbacks
app.get('/example2',
(req, res, next) => {
console.log('First handler');
next();
},
(req, res) => {
res.send('Second handler');
}
);
// Array of callbacks
const cb1 = (req, res, next) => {
console.log('CB1');
next();
};
const cb2 = (req, res, next) => {
console.log('CB2');
next();
};
app.get('/example3', [cb1, cb2], (req, res) => {
res.send('Array of callbacks');
});
Express Router
// routes/userRoutes.ts
import { Router } from 'express';
import * as userController from '../controllers/userController';
const router = Router();
router.get('/', userController.getAllUsers);
router.get('/:id', userController.getUserById);
router.post('/', userController.createUser);
router.put('/:id', userController.updateUser);
router.delete('/:id', userController.deleteUser);
export default router;
// index.ts
import userRoutes from './routes/userRoutes';
app.use('/api/users', userRoutes);
Route Chaining
app.route('/users')
.get((req, res) => {
res.json({ message: 'Get all users' });
})
.post((req, res) => {
res.json({ message: 'Create user' });
});
app.route('/users/:id')
.get((req, res) => {
res.json({ message: 'Get user' });
})
.put((req, res) => {
res.json({ message: 'Update user' });
})
.delete((req, res) => {
res.json({ message: 'Delete user' });
});
Middleware
Middleware functions have access to request, response, and the next middleware function in the application's request-response cycle.
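This request-response cycle can be modeled as a chain of functions, each deciding whether to hand off to the next. A toy dispatcher, heavily simplified from what Express actually does, makes the mechanics concrete:

```javascript
// Minimal middleware runner: each fn gets (req, res, next);
// calling next() advances the chain, not calling it ends the cycle.
function run(middlewares, req, res) {
  let i = 0;
  function next(err) {
    if (err) { res.error = err; return; } // real Express jumps to error handlers
    const fn = middlewares[i++];
    if (fn) fn(req, res, next);
  }
  next();
}

const req = { log: [] };
run([
  (req, res, next) => { req.log.push('auth'); next(); },
  (req, res, next) => { req.log.push('handler'); },       // never calls next()
  (req, res, next) => { req.log.push('never'); next(); }, // unreachable
], req, {});
// req.log is ['auth', 'handler'] — the third function never runs
```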
Built-in Middleware
import express from 'express';
const app = express();
// Parse JSON bodies
app.use(express.json());
// Parse URL-encoded bodies
app.use(express.urlencoded({ extended: true }));
// Serve static files
app.use(express.static('public'));
Application-Level Middleware
// Executed for every request
app.use((req, res, next) => {
console.log(`${req.method} ${req.url}`);
next();
});
// Executed for specific path
app.use('/api', (req, res, next) => {
console.log('API request');
next();
});
Router-Level Middleware
const router = express.Router();
router.use((req, res, next) => {
console.log('Router middleware');
next();
});
router.get('/users', (req, res) => {
res.json({ message: 'Users' });
});
app.use('/api', router);
Custom Middleware
import { Request, Response, NextFunction } from 'express';
// Logger middleware
const logger = (req: Request, res: Response, next: NextFunction) => {
const timestamp = new Date().toISOString();
console.log(`[${timestamp}] ${req.method} ${req.path}`);
next();
};
// Request timing middleware
const requestTimer = (req: Request, res: Response, next: NextFunction) => {
const start = Date.now();
res.on('finish', () => {
const duration = Date.now() - start;
console.log(`Request took ${duration}ms`);
});
next();
};
// Auth middleware
const authenticate = (req: Request, res: Response, next: NextFunction) => {
const token = req.headers.authorization;
if (!token) {
return res.status(401).json({ error: 'No token provided' });
}
try {
// Verify token — verifyToken is app-specific (e.g., a jwt.verify wrapper; see Authentication below)
const decoded = verifyToken(token);
req.user = decoded;
next();
} catch (error) {
res.status(401).json({ error: 'Invalid token' });
}
};
// Usage
app.use(logger);
app.use(requestTimer);
app.use('/api/protected', authenticate);
Third-Party Middleware
// CORS
import cors from 'cors';
app.use(cors({
origin: 'http://localhost:3000',
credentials: true
}));
// Helmet (security headers)
import helmet from 'helmet';
app.use(helmet());
// Compression
import compression from 'compression';
app.use(compression());
// Cookie parser
import cookieParser from 'cookie-parser';
app.use(cookieParser());
// Morgan (HTTP request logger)
import morgan from 'morgan';
app.use(morgan('combined'));
// Express validator
import { body, validationResult } from 'express-validator';
app.post('/users',
body('email').isEmail(),
body('password').isLength({ min: 6 }),
(req, res) => {
const errors = validationResult(req);
if (!errors.isEmpty()) {
return res.status(400).json({ errors: errors.array() });
}
// Process request
}
);
Request and Response
Request Object
app.post('/example', (req: Request, res: Response) => {
// Request body (requires body-parser or express.json())
console.log(req.body);
// URL parameters
console.log(req.params);
// Query parameters
console.log(req.query);
// Headers
console.log(req.headers);
console.log(req.get('Content-Type'));
// Cookies (requires cookie-parser)
console.log(req.cookies);
console.log(req.signedCookies);
// Request URL info
console.log(req.protocol); // http or https
console.log(req.hostname); // Host name
console.log(req.path); // Path part of URL
console.log(req.originalUrl); // Original URL
console.log(req.baseUrl); // Base URL
// Request method
console.log(req.method); // GET, POST, etc.
// IP address
console.log(req.ip);
console.log(req.ips);
// Check content type
console.log(req.is('json'));
console.log(req.is('html'));
res.send('OK');
});
Response Object
// Illustrative only — a real handler must send at most one response per request
app.get('/response-examples', (req: Request, res: Response) => {
// Send text
res.send('Hello World');
// Send JSON
res.json({ message: 'Success', data: [] });
// Set status code and send
res.status(201).json({ message: 'Created' });
// Send file
res.sendFile('/path/to/file.pdf');
// Download file
res.download('/path/to/file.pdf', 'filename.pdf');
// Redirect
res.redirect('/new-url');
res.redirect(301, '/permanent-redirect');
// Set headers
res.set('Content-Type', 'text/html');
res.set({
'Content-Type': 'text/html',
'X-Custom-Header': 'value'
});
// Set cookies
res.cookie('name', 'value', {
maxAge: 900000,
httpOnly: true,
secure: true
});
// Clear cookie
res.clearCookie('name');
// Render view (requires template engine)
res.render('index', { title: 'Home' });
// End response
res.end();
// Send status with message
res.sendStatus(404); // Sends "Not Found"
});
Response Status Codes
// Success
res.status(200).json({ message: 'OK' });
res.status(201).json({ message: 'Created' });
res.status(204).send(); // No Content
// Client Errors
res.status(400).json({ error: 'Bad Request' });
res.status(401).json({ error: 'Unauthorized' });
res.status(403).json({ error: 'Forbidden' });
res.status(404).json({ error: 'Not Found' });
res.status(422).json({ error: 'Unprocessable Entity' });
// Server Errors
res.status(500).json({ error: 'Internal Server Error' });
res.status(503).json({ error: 'Service Unavailable' });
Error Handling
Basic Error Handling
// Synchronous error
app.get('/sync-error', (req, res) => {
throw new Error('Synchronous error');
});
// Asynchronous error (must use next)
app.get('/async-error', (req, res, next) => {
setTimeout(() => {
try {
throw new Error('Async error');
} catch (err) {
next(err);
}
}, 100);
});
// Promise rejection
app.get('/promise-error', async (req, res, next) => {
try {
await someAsyncOperation();
res.json({ success: true });
} catch (err) {
next(err);
}
});
Error Handling Middleware
// Error handler (must have 4 parameters)
app.use((err: Error, req: Request, res: Response, next: NextFunction) => {
console.error(err.stack);
res.status(500).json({
error: {
message: err.message,
stack: process.env.NODE_ENV === 'development' ? err.stack : undefined
}
});
});
Custom Error Classes
// errors/AppError.ts
export class AppError extends Error {
statusCode: number;
isOperational: boolean;
constructor(message: string, statusCode: number) {
super(message);
this.statusCode = statusCode;
this.isOperational = true;
Error.captureStackTrace(this, this.constructor);
}
}
export class ValidationError extends AppError {
constructor(message: string) {
super(message, 400);
}
}
export class NotFoundError extends AppError {
constructor(message: string = 'Resource not found') {
super(message, 404);
}
}
export class UnauthorizedError extends AppError {
constructor(message: string = 'Unauthorized') {
super(message, 401);
}
}
// Usage in controllers
import { NotFoundError } from '../errors/AppError';
app.get('/users/:id', async (req, res, next) => {
try {
const user = await findUserById(req.params.id);
if (!user) {
throw new NotFoundError('User not found');
}
res.json(user);
} catch (err) {
next(err);
}
});
// Error handler
app.use((err: Error | AppError, req: Request, res: Response, next: NextFunction) => {
if (err instanceof AppError) {
return res.status(err.statusCode).json({
error: {
message: err.message,
statusCode: err.statusCode
}
});
}
// Unknown error
console.error('Unknown error:', err);
res.status(500).json({
error: {
message: 'Internal server error'
}
});
});
Async Error Wrapper
// utils/asyncHandler.ts
export const asyncHandler = (fn: Function) => {
return (req: Request, res: Response, next: NextFunction) => {
Promise.resolve(fn(req, res, next)).catch(next);
};
};
// Usage
app.get('/users', asyncHandler(async (req: Request, res: Response) => {
const users = await User.find();
res.json(users);
}));
404 Handler
// Catch-all 404 handler — register this after all other routes
app.use((req, res, next) => {
res.status(404).json({
error: {
message: 'Route not found',
path: req.originalUrl
}
});
});
Template Engines
EJS (Embedded JavaScript)
npm install ejs
import express from 'express';
const app = express();
// Set view engine
app.set('view engine', 'ejs');
app.set('views', './views');
// Render template
app.get('/', (req, res) => {
res.render('index', {
title: 'Home Page',
user: { name: 'John' }
});
});
views/index.ejs:
<!DOCTYPE html>
<html>
<head>
<title><%= title %></title>
</head>
<body>
<h1>Welcome, <%= user.name %>!</h1>
<% if (user.isAdmin) { %>
<p>Admin panel</p>
<% } %>
<ul>
<% ['Item 1', 'Item 2', 'Item 3'].forEach(item => { %>
<li><%= item %></li>
<% }); %>
</ul>
</body>
</html>
Pug (formerly Jade)
npm install pug
app.set('view engine', 'pug');
app.set('views', './views');
app.get('/', (req, res) => {
res.render('index', { title: 'Home', message: 'Hello Pug!' });
});
views/index.pug:
html
head
title= title
body
h1= message
ul
each item in ['Item 1', 'Item 2', 'Item 3']
li= item
Handlebars
npm install express-handlebars
import { engine } from 'express-handlebars';
app.engine('handlebars', engine());
app.set('view engine', 'handlebars');
app.set('views', './views');
app.get('/', (req, res) => {
res.render('home', {
title: 'Home',
items: ['Item 1', 'Item 2', 'Item 3']
});
});
Static Files
Serving Static Files
// Serve from 'public' directory
app.use(express.static('public'));
// Now you can access:
// http://localhost:3000/images/logo.png
// http://localhost:3000/css/style.css
// http://localhost:3000/js/app.js
// Multiple static directories
app.use(express.static('public'));
app.use(express.static('files'));
// Virtual path prefix
app.use('/static', express.static('public'));
// Now: http://localhost:3000/static/images/logo.png
// Absolute path
import path from 'path';
app.use('/static', express.static(path.join(__dirname, 'public')));
Static File Options
app.use(express.static('public', {
maxAge: '1d', // Cache for 1 day
dotfiles: 'ignore', // Ignore dotfiles
index: 'index.html', // Directory index file
extensions: ['html'], // File extension fallbacks
setHeaders: (res, path) => {
res.set('X-Custom-Header', 'value');
}
}));
Database Integration
MongoDB with Mongoose
npm install mongoose
// config/database.ts
import mongoose from 'mongoose';
export const connectDatabase = async () => {
try {
await mongoose.connect(process.env.MONGODB_URI || 'mongodb://localhost:27017/myapp');
console.log('MongoDB connected');
} catch (error) {
console.error('MongoDB connection error:', error);
process.exit(1);
}
};
// models/User.ts
import mongoose, { Document, Schema } from 'mongoose';
export interface IUser extends Document {
name: string;
email: string;
password: string;
createdAt: Date;
}
const UserSchema = new Schema({
name: { type: String, required: true },
email: { type: String, required: true, unique: true },
password: { type: String, required: true },
createdAt: { type: Date, default: Date.now }
});
export default mongoose.model<IUser>('User', UserSchema);
// controllers/userController.ts
import User from '../models/User';
export const getAllUsers = async (req: Request, res: Response) => {
try {
const users = await User.find().select('-password');
res.json(users);
} catch (error) {
res.status(500).json({ error: 'Server error' });
}
};
export const createUser = async (req: Request, res: Response) => {
try {
const user = new User(req.body);
await user.save();
res.status(201).json(user);
} catch (error) {
res.status(400).json({ error: 'Invalid data' });
}
};
// index.ts
import { connectDatabase } from './config/database';
connectDatabase();
PostgreSQL with Sequelize
npm install sequelize pg pg-hstore
// config/database.ts
import { Sequelize } from 'sequelize';
export const sequelize = new Sequelize(
process.env.DB_NAME || 'myapp',
process.env.DB_USER || 'postgres',
process.env.DB_PASSWORD || 'password',
{
host: process.env.DB_HOST || 'localhost',
dialect: 'postgres',
logging: false
}
);
export const connectDatabase = async () => {
try {
await sequelize.authenticate();
console.log('PostgreSQL connected');
await sequelize.sync();
} catch (error) {
console.error('Database connection error:', error);
}
};
// models/User.ts
import { DataTypes, Model } from 'sequelize';
import { sequelize } from '../config/database';
export class User extends Model {
public id!: number;
public name!: string;
public email!: string;
public readonly createdAt!: Date;
}
User.init(
{
id: {
type: DataTypes.INTEGER,
autoIncrement: true,
primaryKey: true
},
name: {
type: DataTypes.STRING,
allowNull: false
},
email: {
type: DataTypes.STRING,
allowNull: false,
unique: true
}
},
{
sequelize,
tableName: 'users'
}
);
MySQL with mysql2
npm install mysql2
import mysql from 'mysql2/promise';
const pool = mysql.createPool({
host: 'localhost',
user: 'root',
password: 'password',
database: 'myapp',
waitForConnections: true,
connectionLimit: 10,
queueLimit: 0
});
app.get('/users', async (req, res) => {
try {
const [rows] = await pool.query('SELECT * FROM users');
res.json(rows);
} catch (error) {
res.status(500).json({ error: 'Database error' });
}
});
Authentication
JWT Authentication
npm install jsonwebtoken bcryptjs
npm install --save-dev @types/jsonwebtoken @types/bcryptjs
import jwt from 'jsonwebtoken';
import bcrypt from 'bcryptjs';
const JWT_SECRET = process.env.JWT_SECRET || 'your-secret-key';
// Register
app.post('/auth/register', async (req, res) => {
try {
const { email, password, name } = req.body;
// Check if user exists
const existingUser = await User.findOne({ email });
if (existingUser) {
return res.status(400).json({ error: 'User already exists' });
}
// Hash password
const hashedPassword = await bcrypt.hash(password, 10);
// Create user
const user = new User({
email,
password: hashedPassword,
name
});
await user.save();
// Generate token
const token = jwt.sign(
{ userId: user.id, email: user.email },
JWT_SECRET,
{ expiresIn: '7d' }
);
res.status(201).json({ token, user: { id: user.id, email, name } });
} catch (error) {
res.status(500).json({ error: 'Registration failed' });
}
});
// Login
app.post('/auth/login', async (req, res) => {
try {
const { email, password } = req.body;
// Find user
const user = await User.findOne({ email });
if (!user) {
return res.status(401).json({ error: 'Invalid credentials' });
}
// Verify password
const isValidPassword = await bcrypt.compare(password, user.password);
if (!isValidPassword) {
return res.status(401).json({ error: 'Invalid credentials' });
}
// Generate token
const token = jwt.sign(
{ userId: user.id, email: user.email },
JWT_SECRET,
{ expiresIn: '7d' }
);
res.json({
token,
user: { id: user.id, email: user.email, name: user.name }
});
} catch (error) {
res.status(500).json({ error: 'Login failed' });
}
});
// Auth middleware
interface AuthRequest extends Request {
user?: any;
}
const authenticate = (req: AuthRequest, res: Response, next: NextFunction) => {
try {
const token = req.headers.authorization?.split(' ')[1];
if (!token) {
return res.status(401).json({ error: 'No token provided' });
}
const decoded = jwt.verify(token, JWT_SECRET);
req.user = decoded;
next();
} catch (error) {
res.status(401).json({ error: 'Invalid token' });
}
};
// Protected route
app.get('/profile', authenticate, async (req: AuthRequest, res) => {
try {
const user = await User.findById(req.user.userId).select('-password');
res.json(user);
} catch (error) {
res.status(500).json({ error: 'Server error' });
}
});
Session-Based Authentication
npm install express-session connect-mongo
import session from 'express-session';
import MongoStore from 'connect-mongo';
app.use(session({
secret: process.env.SESSION_SECRET || 'your-secret',
resave: false,
saveUninitialized: false,
store: MongoStore.create({
mongoUrl: process.env.MONGODB_URI
}),
cookie: {
secure: process.env.NODE_ENV === 'production',
httpOnly: true,
maxAge: 1000 * 60 * 60 * 24 * 7 // 7 days
}
}));
// Login
app.post('/login', async (req, res) => {
const { email, password } = req.body;
const user = await User.findOne({ email });
if (!user || !(await bcrypt.compare(password, user.password))) {
return res.status(401).json({ error: 'Invalid credentials' });
}
req.session.userId = user.id; // TypeScript: augment express-session's SessionData to add userId
res.json({ message: 'Logged in successfully' });
});
// Logout
app.post('/logout', (req, res) => {
req.session.destroy((err) => {
if (err) {
return res.status(500).json({ error: 'Logout failed' });
}
res.clearCookie('connect.sid');
res.json({ message: 'Logged out successfully' });
});
});
// Auth middleware
const requireAuth = (req: Request, res: Response, next: NextFunction) => {
if (!req.session.userId) {
return res.status(401).json({ error: 'Unauthorized' });
}
next();
};
RESTful API
Complete REST API Example
// routes/api/users.ts
import { Router } from 'express';
import {
getAllUsers,
getUserById,
createUser,
updateUser,
deleteUser
} from '../../controllers/userController';
import { authenticate } from '../../middleware/auth';
import { validateUser } from '../../middleware/validation';
const router = Router();
// GET /api/users - Get all users
router.get('/', authenticate, getAllUsers);
// GET /api/users/:id - Get user by ID
router.get('/:id', authenticate, getUserById);
// POST /api/users - Create new user
router.post('/', validateUser, createUser);
// PUT /api/users/:id - Update user
router.put('/:id', authenticate, validateUser, updateUser);
// DELETE /api/users/:id - Delete user
router.delete('/:id', authenticate, deleteUser);
export default router;
// controllers/userController.ts
import { Request, Response } from 'express';
import User from '../models/User';
export const getAllUsers = async (req: Request, res: Response) => {
try {
const page = parseInt(req.query.page as string) || 1;
const limit = parseInt(req.query.limit as string) || 10;
const skip = (page - 1) * limit;
const users = await User.find()
.select('-password')
.limit(limit)
.skip(skip);
const total = await User.countDocuments();
res.json({
users,
pagination: {
page,
limit,
total,
pages: Math.ceil(total / limit)
}
});
} catch (error) {
res.status(500).json({ error: 'Server error' });
}
};
export const getUserById = async (req: Request, res: Response) => {
try {
const user = await User.findById(req.params.id).select('-password');
if (!user) {
return res.status(404).json({ error: 'User not found' });
}
res.json(user);
} catch (error) {
res.status(500).json({ error: 'Server error' });
}
};
export const createUser = async (req: Request, res: Response) => {
try {
const user = new User(req.body);
await user.save();
const userResponse = user.toObject();
delete userResponse.password;
res.status(201).json(userResponse);
} catch (error) {
res.status(400).json({ error: 'Invalid data' });
}
};
export const updateUser = async (req: Request, res: Response) => {
try {
const user = await User.findByIdAndUpdate(
req.params.id,
req.body,
{ new: true, runValidators: true }
).select('-password');
if (!user) {
return res.status(404).json({ error: 'User not found' });
}
res.json(user);
} catch (error) {
res.status(400).json({ error: 'Invalid data' });
}
};
export const deleteUser = async (req: Request, res: Response) => {
try {
const user = await User.findByIdAndDelete(req.params.id);
if (!user) {
return res.status(404).json({ error: 'User not found' });
}
res.status(204).send();
} catch (error) {
res.status(500).json({ error: 'Server error' });
}
};
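The pagination arithmetic inside getAllUsers above is easy to get wrong; factored out into a plain helper it becomes trivially unit-testable. A sketch (this helper is not part of the controller, just an illustration):

```javascript
// Compute Mongo-style skip/limit paging metadata for a page request.
function pageMeta(page, limit, total) {
  return {
    page,
    limit,
    total,
    skip: (page - 1) * limit,          // documents to skip before this page
    pages: Math.ceil(total / limit),   // total number of pages
  };
}

// pageMeta(2, 10, 35) → { page: 2, limit: 10, total: 35, skip: 10, pages: 4 }
```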
API Versioning
// v1 routes
import v1Router from './routes/v1';
app.use('/api/v1', v1Router);
// v2 routes
import v2Router from './routes/v2';
app.use('/api/v2', v2Router);
Rate Limiting
npm install express-rate-limit
import rateLimit from 'express-rate-limit';
const limiter = rateLimit({
windowMs: 15 * 60 * 1000, // 15 minutes
max: 100, // Limit each IP to 100 requests per windowMs
message: 'Too many requests from this IP'
});
app.use('/api/', limiter);
// Different limits for different routes
const authLimiter = rateLimit({
windowMs: 15 * 60 * 1000,
max: 5,
message: 'Too many login attempts'
});
app.use('/api/auth/login', authLimiter);
File Uploads
Multer for File Uploads
npm install multer
npm install --save-dev @types/multer
import multer from 'multer';
import path from 'path';
// Storage configuration
const storage = multer.diskStorage({
destination: (req, file, cb) => {
cb(null, 'uploads/');
},
filename: (req, file, cb) => {
const uniqueSuffix = Date.now() + '-' + Math.round(Math.random() * 1E9);
cb(null, file.fieldname + '-' + uniqueSuffix + path.extname(file.originalname));
}
});
// File filter
const fileFilter = (req: Request, file: Express.Multer.File, cb: multer.FileFilterCallback) => {
const allowedTypes = ['image/jpeg', 'image/png', 'image/gif'];
if (allowedTypes.includes(file.mimetype)) {
cb(null, true);
} else {
cb(new Error('Invalid file type'));
}
};
const upload = multer({
storage: storage,
limits: {
fileSize: 5 * 1024 * 1024 // 5MB
},
fileFilter: fileFilter
});
// Single file upload
app.post('/upload', upload.single('avatar'), (req, res) => {
if (!req.file) {
return res.status(400).json({ error: 'No file uploaded' });
}
res.json({
message: 'File uploaded successfully',
file: {
filename: req.file.filename,
path: req.file.path,
size: req.file.size
}
});
});
// Multiple files
app.post('/upload-multiple', upload.array('photos', 5), (req, res) => {
res.json({
message: 'Files uploaded successfully',
files: req.files
});
});
// Multiple fields
app.post('/upload-fields',
upload.fields([
{ name: 'avatar', maxCount: 1 },
{ name: 'gallery', maxCount: 5 }
]),
(req, res) => {
res.json({
message: 'Files uploaded successfully',
files: req.files
});
}
);
Security Best Practices
Essential Security Packages
npm install helmet cors express-rate-limit express-validator
npm install --save-dev @types/cors
import helmet from 'helmet';
import cors from 'cors';
import rateLimit from 'express-rate-limit';
// Helmet - Set security headers
app.use(helmet());
// CORS configuration
app.use(cors({
origin: process.env.ALLOWED_ORIGINS?.split(',') || 'http://localhost:3000',
credentials: true,
methods: ['GET', 'POST', 'PUT', 'DELETE', 'PATCH'],
allowedHeaders: ['Content-Type', 'Authorization']
}));
// Rate limiting
const limiter = rateLimit({
windowMs: 15 * 60 * 1000,
max: 100
});
app.use(limiter);
// Prevent parameter pollution
import hpp from 'hpp';
app.use(hpp());
// Sanitize data
import mongoSanitize from 'express-mongo-sanitize';
app.use(mongoSanitize());
// XSS protection (note: xss-clean is unmaintained; prefer validating input and encoding output)
import xss from 'xss-clean';
app.use(xss());
Input Validation
import { body, param, validationResult } from 'express-validator';
app.post('/users',
body('email').isEmail().normalizeEmail(),
body('password').isLength({ min: 8 }).matches(/\d/).matches(/[a-zA-Z]/),
body('name').trim().isLength({ min: 2, max: 50 }),
(req, res) => {
const errors = validationResult(req);
if (!errors.isEmpty()) {
return res.status(400).json({ errors: errors.array() });
}
// Process request
}
);
SQL Injection Prevention
// Use parameterized queries
const [rows] = await pool.query(
'SELECT * FROM users WHERE email = ?',
[email]
);
// Use ORM/ODM
const user = await User.findOne({ email }); // Mongoose
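To see why the placeholder form matters, compare naive string interpolation with keeping input as data. A deliberately unsafe example — never build SQL this way:

```javascript
// UNSAFE: user input is spliced into the SQL text itself.
function unsafeQuery(email) {
  return `SELECT * FROM users WHERE email = '${email}'`;
}

// With ? placeholders, the driver sends the input separately, as data:
//   pool.query('SELECT * FROM users WHERE email = ?', [email])

const malicious = "x' OR '1'='1";
unsafeQuery(malicious);
// → "SELECT * FROM users WHERE email = 'x' OR '1'='1'" — now matches every row
```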
HTTPS Enforcement
// Redirect HTTP to HTTPS
app.use((req, res, next) => {
if (req.header('x-forwarded-proto') !== 'https' && process.env.NODE_ENV === 'production') {
res.redirect(`https://${req.header('host')}${req.url}`);
} else {
next();
}
});
Testing
Jest and Supertest
npm install --save-dev jest supertest @types/jest @types/supertest ts-jest
jest.config.js:
module.exports = {
preset: 'ts-jest',
testEnvironment: 'node',
testMatch: ['**/__tests__/**/*.ts', '**/?(*.)+(spec|test).ts'],
collectCoverageFrom: ['src/**/*.ts', '!src/**/*.d.ts']
};
tests/app.test.ts:
import request from 'supertest';
import app from '../src/app';
describe('User API', () => {
it('GET /api/users should return all users', async () => {
const response = await request(app)
.get('/api/users')
.expect('Content-Type', /json/)
.expect(200);
expect(response.body).toHaveProperty('users');
expect(Array.isArray(response.body.users)).toBe(true);
});
it('POST /api/users should create a user', async () => {
const newUser = {
name: 'John Doe',
email: 'john@example.com',
password: 'password123'
};
const response = await request(app)
.post('/api/users')
.send(newUser)
.expect('Content-Type', /json/)
.expect(201);
expect(response.body).toHaveProperty('id');
expect(response.body.email).toBe(newUser.email);
});
it('GET /api/users/:id should return a user', async () => {
const response = await request(app)
.get('/api/users/1')
.expect(200);
expect(response.body).toHaveProperty('id');
expect(response.body).toHaveProperty('name');
});
it('PUT /api/users/:id should update a user', async () => {
const updates = { name: 'Jane Doe' };
const response = await request(app)
.put('/api/users/1')
.send(updates)
.expect(200);
expect(response.body.name).toBe(updates.name);
});
it('DELETE /api/users/:id should delete a user', async () => {
await request(app)
.delete('/api/users/1')
.expect(204);
});
});
describe('Authentication', () => {
it('POST /auth/register should register a user', async () => {
const user = {
name: 'Test User',
email: 'test@example.com',
password: 'password123'
};
const response = await request(app)
.post('/auth/register')
.send(user)
.expect(201);
expect(response.body).toHaveProperty('token');
expect(response.body).toHaveProperty('user');
});
it('POST /auth/login should login a user', async () => {
const credentials = {
email: 'test@example.com',
password: 'password123'
};
const response = await request(app)
.post('/auth/login')
.send(credentials)
.expect(200);
expect(response.body).toHaveProperty('token');
});
});
Production Deployment
Environment Variables
.env:
NODE_ENV=production
PORT=3000
DATABASE_URL=mongodb://localhost:27017/myapp
JWT_SECRET=your-jwt-secret
SESSION_SECRET=your-session-secret
ALLOWED_ORIGINS=https://yourdomain.com
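A .env file does nothing by itself — in development it's typically loaded with the dotenv package (`require('dotenv').config()`), and it's worth failing fast when a required variable is missing rather than crashing later. A small framework-free validation sketch (the helper name is mine; the variable names match the file above):

```javascript
// Assert required environment variables exist before the app starts.
function requireEnv(env, keys) {
  const missing = keys.filter((k) => !env[k]);
  if (missing.length) {
    throw new Error(`Missing required env vars: ${missing.join(', ')}`);
  }
  // Return only the requested keys for explicit downstream use.
  return keys.reduce((out, k) => ({ ...out, [k]: env[k] }), {});
}

// At startup, after dotenv has populated process.env:
//   const config = requireEnv(process.env, ['JWT_SECRET', 'DATABASE_URL']);
```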
Process Manager (PM2)
npm install -g pm2
# Start application
pm2 start dist/index.js --name "my-app"
# Start with cluster mode
pm2 start dist/index.js -i max --name "my-app"
# Save configuration
pm2 save
# Startup script
pm2 startup
ecosystem.config.js:
module.exports = {
apps: [{
name: 'my-app',
script: './dist/index.js',
instances: 'max',
exec_mode: 'cluster',
env: {
NODE_ENV: 'production',
PORT: 3000
},
error_file: './logs/error.log',
out_file: './logs/out.log',
log_date_format: 'YYYY-MM-DD HH:mm:ss'
}]
};
Docker Deployment
Dockerfile:
FROM node:18-alpine
WORKDIR /app
COPY package*.json ./
RUN npm ci
COPY . .
RUN npm run build
RUN npm prune --omit=dev
EXPOSE 3000
CMD ["node", "dist/index.js"]
docker-compose.yml:
version: '3.8'
services:
app:
build: .
ports:
- "3000:3000"
environment:
- NODE_ENV=production
- DATABASE_URL=mongodb://mongo:27017/myapp
depends_on:
- mongo
mongo:
image: mongo:6
volumes:
- mongo-data:/data/db
ports:
- "27017:27017"
volumes:
mongo-data:
Nginx Reverse Proxy
server {
listen 80;
server_name yourdomain.com;
location / {
proxy_pass http://localhost:3000;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection 'upgrade';
proxy_set_header Host $host;
proxy_cache_bypass $http_upgrade;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
}
Performance Optimization
// Compression
import compression from 'compression';
app.use(compression());
// Response caching
import apicache from 'apicache';
const cache = apicache.middleware;
app.get('/api/users', cache('5 minutes'), getAllUsers);
// Database connection pooling
mongoose.connect(uri, {
maxPoolSize: 10,
minPoolSize: 5
});
// Clustering
import cluster from 'cluster';
import os from 'os';
if (cluster.isPrimary) {
const cpuCount = os.cpus().length;
for (let i = 0; i < cpuCount; i++) {
cluster.fork();
}
cluster.on('exit', (worker) => {
console.log(`Worker ${worker.process.pid} died`);
cluster.fork();
});
} else {
app.listen(PORT);
}
Resources
- Official Documentation: https://expressjs.com/
- GitHub Repository: https://github.com/expressjs/express
- Express Generator: https://expressjs.com/en/starter/generator.html
- Best Practices: https://expressjs.com/en/advanced/best-practice-performance.html
Express.js remains the most popular Node.js framework due to its simplicity, flexibility, and robust ecosystem. Its minimalist approach allows developers to structure applications as they see fit, making it suitable for everything from small APIs to large-scale enterprise applications.
NestJS
NestJS is a progressive Node.js framework for building efficient, reliable, and scalable server-side applications. It uses TypeScript by default and combines elements of Object-Oriented Programming (OOP), Functional Programming (FP), and Functional Reactive Programming (FRP). NestJS is heavily inspired by Angular's architecture and provides a robust application structure out of the box.
Table of Contents
- Introduction
- Installation and Setup
- Core Concepts
- Controllers
- Providers and Dependency Injection
- Modules
- Middleware
- Exception Filters
- Pipes
- Guards
- Interceptors
- Database Integration
- Authentication and Authorization
- GraphQL
- Microservices
- Testing
- Best Practices
- Production Deployment
Introduction
Key Features:
- TypeScript-first framework with full TypeScript support
- Modular architecture with dependency injection
- Built on top of Express (or Fastify)
- Extensive ecosystem and CLI tools
- WebSockets and GraphQL support
- Microservices architecture support
- Comprehensive testing utilities
- OpenAPI (Swagger) integration
- Excellent documentation
Use Cases:
- Enterprise-grade REST APIs
- GraphQL APIs
- Microservices architectures
- Real-time applications with WebSockets
- Server-side rendered applications
- Backend for mobile applications
- Monolithic or distributed systems
Philosophy: NestJS provides an opinionated structure while remaining flexible, making it ideal for large teams and enterprise applications where maintainability and scalability are crucial.
Installation and Setup
Prerequisites
# Node.js 16+ and npm required
node --version
npm --version
Create New Project
# Install NestJS CLI globally
npm install -g @nestjs/cli
# Create new project
nest new my-nest-app
# Navigate to project
cd my-nest-app
# Start development server
npm run start:dev
Manual Setup
# Create project directory
mkdir my-nest-app
cd my-nest-app
# Initialize npm
npm init -y
# Install core dependencies
npm install @nestjs/common @nestjs/core @nestjs/platform-express reflect-metadata rxjs
# Install development dependencies
npm install -D @nestjs/cli @nestjs/schematics typescript @types/node @types/express ts-node
Project Structure
my-nest-app/
├── src/
│ ├── main.ts # Application entry point
│ ├── app.module.ts # Root module
│ ├── app.controller.ts # Root controller
│ ├── app.service.ts # Root service
│ ├── modules/ # Feature modules
│ │ ├── users/
│ │ │ ├── users.module.ts
│ │ │ ├── users.controller.ts
│ │ │ ├── users.service.ts
│ │ │ ├── dto/ # Data Transfer Objects
│ │ │ ├── entities/ # Database entities
│ │ │ └── interfaces/ # TypeScript interfaces
│ │ └── auth/
│ ├── common/ # Shared utilities
│ │ ├── guards/
│ │ ├── interceptors/
│ │ ├── pipes/
│ │ ├── filters/
│ │ └── decorators/
│ └── config/ # Configuration files
├── test/ # E2E tests
├── nest-cli.json # NestJS CLI configuration
├── tsconfig.json # TypeScript configuration
└── package.json
Configuration Files
tsconfig.json:
{
"compilerOptions": {
"module": "commonjs",
"declaration": true,
"removeComments": true,
"emitDecoratorMetadata": true,
"experimentalDecorators": true,
"allowSyntheticDefaultImports": true,
"target": "ES2021",
"sourceMap": true,
"outDir": "./dist",
"baseUrl": "./",
"incremental": true,
"skipLibCheck": true,
"strictNullChecks": false,
"noImplicitAny": false,
"strictBindCallApply": false,
"forceConsistentCasingInFileNames": false,
"noFallthroughCasesInSwitch": false
}
}
nest-cli.json:
{
"collection": "@nestjs/schematics",
"sourceRoot": "src",
"compilerOptions": {
"deleteOutDir": true
}
}
Core Concepts
Application Bootstrap
src/main.ts:
import { NestFactory } from '@nestjs/core';
import { AppModule } from './app.module';
import { ValidationPipe } from '@nestjs/common';
import { DocumentBuilder, SwaggerModule } from '@nestjs/swagger'; // requires: npm install @nestjs/swagger
async function bootstrap() {
const app = await NestFactory.create(AppModule);
// Enable CORS
app.enableCors();
// Global prefix
app.setGlobalPrefix('api/v1');
// Global validation pipe
app.useGlobalPipes(new ValidationPipe({
whitelist: true,
forbidNonWhitelisted: true,
transform: true,
}));
// Swagger documentation
const config = new DocumentBuilder()
.setTitle('My API')
.setDescription('API documentation')
.setVersion('1.0')
.addBearerAuth()
.build();
const document = SwaggerModule.createDocument(app, config);
SwaggerModule.setup('api/docs', app, document);
await app.listen(3000);
console.log(`Application is running on: ${await app.getUrl()}`);
}
bootstrap();
Controllers
Controllers handle incoming requests and return responses to the client.
Basic Controller
import { Controller, Get, Post, Put, Delete, Body, Param, Query, HttpCode, HttpStatus } from '@nestjs/common';
import { UsersService } from './users.service';
import { CreateUserDto } from './dto/create-user.dto';
import { UpdateUserDto } from './dto/update-user.dto';
@Controller('users')
export class UsersController {
constructor(private readonly usersService: UsersService) {}
@Get()
findAll(@Query('page') page: number = 1, @Query('limit') limit: number = 10) {
return this.usersService.findAll(page, limit);
}
@Get(':id')
findOne(@Param('id') id: string) {
return this.usersService.findOne(+id);
}
@Post()
@HttpCode(HttpStatus.CREATED)
create(@Body() createUserDto: CreateUserDto) {
return this.usersService.create(createUserDto);
}
@Put(':id')
update(@Param('id') id: string, @Body() updateUserDto: UpdateUserDto) {
return this.usersService.update(+id, updateUserDto);
}
@Delete(':id')
@HttpCode(HttpStatus.NO_CONTENT)
remove(@Param('id') id: string) {
return this.usersService.remove(+id);
}
}
Advanced Controller Features
import {
Controller,
Get,
Post,
Body,
UseGuards,
UseInterceptors,
UsePipes,
Req,
Res,
Headers,
Session
} from '@nestjs/common';
import { Request, Response } from 'express';
import { AuthGuard } from '@nestjs/passport';
import { LoggingInterceptor } from '../common/interceptors/logging.interceptor';
import { ValidationPipe } from '@nestjs/common';
@Controller('products')
@UseGuards(AuthGuard('jwt'))
@UseInterceptors(LoggingInterceptor)
export class ProductsController {
@Get()
async findAll(@Req() request: Request, @Headers('authorization') auth: string) {
return {
data: [],
user: request.user,
};
}
@Post()
@UsePipes(new ValidationPipe({ transform: true }))
async create(@Body() body: any, @Res() response: Response) {
// createProduct(): hypothetical persistence helper, not shown here
const result = await this.createProduct(body);
return response.status(201).json(result);
}
@Get('download')
async download(@Res() res: Response) {
res.download('./files/report.pdf');
}
}
Providers and Dependency Injection
Providers are the fundamental concept in NestJS. Services, repositories, factories, and helpers can all be providers.
Basic Service
import { Injectable, NotFoundException } from '@nestjs/common';
import { CreateUserDto } from './dto/create-user.dto';
import { UpdateUserDto } from './dto/update-user.dto';
@Injectable()
export class UsersService {
private users = [];
findAll(page: number, limit: number) {
const start = (page - 1) * limit;
const end = start + limit;
return {
data: this.users.slice(start, end),
total: this.users.length,
page,
limit,
};
}
findOne(id: number) {
const user = this.users.find(u => u.id === id);
if (!user) {
throw new NotFoundException(`User with ID ${id} not found`);
}
return user;
}
create(createUserDto: CreateUserDto) {
const user = {
id: this.users.length + 1,
...createUserDto,
createdAt: new Date(),
};
this.users.push(user);
return user;
}
update(id: number, updateUserDto: UpdateUserDto) {
const user = this.findOne(id);
Object.assign(user, updateUserDto);
return user;
}
remove(id: number) {
const index = this.users.findIndex(u => u.id === id);
if (index === -1) {
throw new NotFoundException(`User with ID ${id} not found`);
}
this.users.splice(index, 1);
}
}
Custom Provider
// Value provider
const configProvider = {
provide: 'CONFIG',
useValue: {
apiKey: process.env.API_KEY,
apiUrl: process.env.API_URL,
},
};
// Factory provider
const databaseProvider = {
provide: 'DATABASE_CONNECTION',
useFactory: async () => {
// createConnection is imported from 'typeorm'
const connection = await createConnection({
type: 'postgres',
host: 'localhost',
port: 5432,
});
return connection;
},
};
// Class provider
const loggerProvider = {
provide: 'LOGGER',
useClass: CustomLogger,
};
// Usage in module
@Module({
providers: [configProvider, databaseProvider, loggerProvider],
})
export class AppModule {}
Dependency Injection
import { Injectable, Inject } from '@nestjs/common';
@Injectable()
export class ProductsService {
constructor(
@Inject('CONFIG') private config: any,
@Inject('DATABASE_CONNECTION') private db: any,
private readonly usersService: UsersService,
) {}
async findAll() {
const apiUrl = this.config.apiUrl;
const users = await this.usersService.findAll(1, 10);
const products = await this.db.query('SELECT * FROM products');
return products;
}
}
Modules
Modules organize the application structure and enable modular architecture.
Feature Module
import { Module } from '@nestjs/common';
import { UsersController } from './users.controller';
import { UsersService } from './users.service';
import { TypeOrmModule } from '@nestjs/typeorm';
import { User } from './entities/user.entity';
@Module({
imports: [TypeOrmModule.forFeature([User])],
controllers: [UsersController],
providers: [UsersService],
exports: [UsersService], // Export to use in other modules
})
export class UsersModule {}
Root Module
import { Module } from '@nestjs/common';
import { ConfigModule } from '@nestjs/config';
import { TypeOrmModule } from '@nestjs/typeorm';
import { UsersModule } from './modules/users/users.module';
import { AuthModule } from './modules/auth/auth.module';
import { ProductsModule } from './modules/products/products.module';
@Module({
imports: [
ConfigModule.forRoot({
isGlobal: true,
envFilePath: '.env',
}),
TypeOrmModule.forRoot({
type: 'postgres',
host: process.env.DB_HOST,
port: parseInt(process.env.DB_PORT),
username: process.env.DB_USERNAME,
password: process.env.DB_PASSWORD,
database: process.env.DB_NAME,
autoLoadEntities: true,
synchronize: process.env.NODE_ENV !== 'production',
}),
UsersModule,
AuthModule,
ProductsModule,
],
})
export class AppModule {}
Global Module
import { Module, Global } from '@nestjs/common';
import { LoggerService } from './logger.service';
@Global()
@Module({
providers: [LoggerService],
exports: [LoggerService],
})
export class LoggerModule {}
Dynamic Module
import { Module, DynamicModule } from '@nestjs/common';
import { DatabaseService } from './database.service';
@Module({})
export class DatabaseModule {
static forRoot(options: DatabaseOptions): DynamicModule {
return {
module: DatabaseModule,
providers: [
{
provide: 'DATABASE_OPTIONS',
useValue: options,
},
DatabaseService,
],
exports: [DatabaseService],
};
}
}
// Usage
@Module({
imports: [
DatabaseModule.forRoot({
host: 'localhost',
port: 5432,
}),
],
})
export class AppModule {}
Middleware
Middleware functions execute before the route handler.
Functional Middleware
import { Request, Response, NextFunction } from 'express';
export function logger(req: Request, res: Response, next: NextFunction) {
console.log(`[${new Date().toISOString()}] ${req.method} ${req.url}`);
next();
}
Class-based Middleware
import { Injectable, NestMiddleware } from '@nestjs/common';
import { Request, Response, NextFunction } from 'express';
@Injectable()
export class LoggerMiddleware implements NestMiddleware {
use(req: Request, res: Response, next: NextFunction) {
console.log(`[${new Date().toISOString()}] ${req.method} ${req.url}`);
next();
}
}
// Apply in module
import { Module, NestModule, MiddlewareConsumer, RequestMethod } from '@nestjs/common';
import { LoggerMiddleware } from './logger.middleware'; // adjust path to where LoggerMiddleware lives
@Module({})
export class AppModule implements NestModule {
configure(consumer: MiddlewareConsumer) {
consumer
.apply(LoggerMiddleware)
.forRoutes('*'); // Apply to all routes
// Or specific routes
consumer
.apply(LoggerMiddleware)
.forRoutes({ path: 'users', method: RequestMethod.GET });
}
}
Exception Filters
Exception filters handle all thrown exceptions.
Custom Exception Filter
import {
ExceptionFilter,
Catch,
ArgumentsHost,
HttpException,
HttpStatus
} from '@nestjs/common';
import { Request, Response } from 'express';
@Catch(HttpException)
export class HttpExceptionFilter implements ExceptionFilter {
catch(exception: HttpException, host: ArgumentsHost) {
const ctx = host.switchToHttp();
const response = ctx.getResponse<Response>();
const request = ctx.getRequest<Request>();
const status = exception.getStatus();
const exceptionResponse = exception.getResponse();
response.status(status).json({
statusCode: status,
timestamp: new Date().toISOString(),
path: request.url,
method: request.method,
message: exceptionResponse['message'] || exception.message,
});
}
}
// Apply globally
app.useGlobalFilters(new HttpExceptionFilter());
// Or use in controller
@Controller('users')
@UseFilters(HttpExceptionFilter)
export class UsersController {}
All Exceptions Filter
@Catch()
export class AllExceptionsFilter implements ExceptionFilter {
catch(exception: unknown, host: ArgumentsHost) {
const ctx = host.switchToHttp();
const response = ctx.getResponse<Response>();
const request = ctx.getRequest<Request>();
const status =
exception instanceof HttpException
? exception.getStatus()
: HttpStatus.INTERNAL_SERVER_ERROR;
const message =
exception instanceof HttpException
? exception.message
: 'Internal server error';
response.status(status).json({
statusCode: status,
timestamp: new Date().toISOString(),
path: request.url,
message,
});
}
}
Pipes
Pipes transform input data or validate it before it reaches the route handler.
Built-in Validation
import { IsString, IsEmail, IsInt, Min, Max, IsOptional, Length } from 'class-validator';
export class CreateUserDto {
@IsString()
@Length(3, 50)
name: string;
@IsEmail()
email: string;
@IsInt()
@Min(18)
@Max(120)
age: number;
@IsString()
@IsOptional()
bio?: string;
}
// Controller
@Post()
create(@Body() createUserDto: CreateUserDto) {
return this.usersService.create(createUserDto);
}
Custom Pipe
import { PipeTransform, Injectable, ArgumentMetadata, BadRequestException } from '@nestjs/common';
@Injectable()
export class ParseIntPipe implements PipeTransform<string, number> {
transform(value: string, metadata: ArgumentMetadata): number {
const val = parseInt(value, 10);
if (isNaN(val)) {
throw new BadRequestException('Validation failed');
}
return val;
}
}
// Usage
@Get(':id')
findOne(@Param('id', ParseIntPipe) id: number) {
return this.usersService.findOne(id);
}
Transformation Pipe
import { PipeTransform, Injectable, ArgumentMetadata } from '@nestjs/common';
@Injectable()
export class TrimPipe implements PipeTransform {
transform(value: any, metadata: ArgumentMetadata) {
if (typeof value === 'string') {
return value.trim();
}
if (typeof value === 'object') {
Object.keys(value).forEach(key => {
if (typeof value[key] === 'string') {
value[key] = value[key].trim();
}
});
}
return value;
}
}
Guards
Guards determine whether a request should be handled by the route handler.
Authentication Guard
import { Injectable, CanActivate, ExecutionContext, UnauthorizedException } from '@nestjs/common';
import { JwtService } from '@nestjs/jwt';
@Injectable()
export class AuthGuard implements CanActivate {
constructor(private jwtService: JwtService) {}
async canActivate(context: ExecutionContext): Promise<boolean> {
const request = context.switchToHttp().getRequest();
const token = this.extractTokenFromHeader(request);
if (!token) {
throw new UnauthorizedException('No token provided');
}
try {
const payload = await this.jwtService.verifyAsync(token);
request.user = payload;
return true;
} catch {
throw new UnauthorizedException('Invalid token');
}
}
private extractTokenFromHeader(request: any): string | undefined {
const [type, token] = request.headers.authorization?.split(' ') ?? [];
return type === 'Bearer' ? token : undefined;
}
}
Roles Guard
import { Injectable, CanActivate, ExecutionContext } from '@nestjs/common';
import { Reflector } from '@nestjs/core';
import { SetMetadata } from '@nestjs/common';
export const ROLES_KEY = 'roles';
export const Roles = (...roles: string[]) => SetMetadata(ROLES_KEY, roles);
@Injectable()
export class RolesGuard implements CanActivate {
constructor(private reflector: Reflector) {}
canActivate(context: ExecutionContext): boolean {
const requiredRoles = this.reflector.getAllAndOverride<string[]>(ROLES_KEY, [
context.getHandler(),
context.getClass(),
]);
if (!requiredRoles) {
return true;
}
const { user } = context.switchToHttp().getRequest();
return requiredRoles.some((role) => user.roles?.includes(role));
}
}
// Usage
@Controller('admin')
@UseGuards(AuthGuard, RolesGuard)
export class AdminController {
@Get()
@Roles('admin')
findAll() {
return 'This route is only for admins';
}
}
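A custom parameter decorator pairs naturally with these guards: once AuthGuard attaches the JWT payload to request.user, handlers can pull it out cleanly. A minimal sketch using NestJS's createParamDecorator (the @CurrentUser name is an illustrative choice, not part of the framework):

```typescript
import { createParamDecorator, ExecutionContext } from '@nestjs/common';

// Extracts the user object that AuthGuard attached to the request.
export const CurrentUser = createParamDecorator(
  (data: string | undefined, ctx: ExecutionContext) => {
    const request = ctx.switchToHttp().getRequest();
    const user = request.user;
    // Optionally pick a single property, e.g. @CurrentUser('email')
    return data ? user?.[data] : user;
  },
);

// Usage in a controller method:
// @Get('me')
// getProfile(@CurrentUser() user: any) { return user; }
```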
Interceptors
Interceptors can transform the result returned from a function or extend basic function behavior.
Logging Interceptor
import {
Injectable,
NestInterceptor,
ExecutionContext,
CallHandler,
} from '@nestjs/common';
import { Observable } from 'rxjs';
import { tap } from 'rxjs/operators';
@Injectable()
export class LoggingInterceptor implements NestInterceptor {
intercept(context: ExecutionContext, next: CallHandler): Observable<any> {
const now = Date.now();
const request = context.switchToHttp().getRequest();
const method = request.method;
const url = request.url;
return next.handle().pipe(
tap(() => {
const responseTime = Date.now() - now;
console.log(`${method} ${url} - ${responseTime}ms`);
}),
);
}
}
Transform Interceptor
import {
Injectable,
NestInterceptor,
ExecutionContext,
CallHandler,
} from '@nestjs/common';
import { Observable } from 'rxjs';
import { map } from 'rxjs/operators';
export interface Response<T> {
data: T;
statusCode: number;
timestamp: string;
}
@Injectable()
export class TransformInterceptor<T> implements NestInterceptor<T, Response<T>> {
intercept(context: ExecutionContext, next: CallHandler): Observable<Response<T>> {
return next.handle().pipe(
map(data => ({
data,
statusCode: context.switchToHttp().getResponse().statusCode,
timestamp: new Date().toISOString(),
})),
);
}
}
Caching Interceptor
import {
Injectable,
NestInterceptor,
ExecutionContext,
CallHandler,
} from '@nestjs/common';
import { Observable, of } from 'rxjs';
import { tap } from 'rxjs/operators';
@Injectable()
export class CacheInterceptor implements NestInterceptor {
private cache = new Map();
intercept(context: ExecutionContext, next: CallHandler): Observable<any> {
const request = context.switchToHttp().getRequest();
const key = request.url;
if (this.cache.has(key)) {
return of(this.cache.get(key));
}
return next.handle().pipe(
tap(response => {
this.cache.set(key, response);
}),
);
}
}
Database Integration
TypeORM Integration
Installation:
npm install @nestjs/typeorm typeorm pg
Entity:
import { Entity, Column, PrimaryGeneratedColumn, CreateDateColumn, UpdateDateColumn } from 'typeorm';
@Entity('users')
export class User {
@PrimaryGeneratedColumn()
id: number;
@Column({ unique: true })
email: string;
@Column()
name: string;
@Column()
password: string;
@Column({ default: true })
isActive: boolean;
@CreateDateColumn()
createdAt: Date;
@UpdateDateColumn()
updatedAt: Date;
}
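Entities can also declare relations between tables. A sketch assuming a hypothetical Post entity owned by User (the posts property on User is an assumption for illustration):

```typescript
import { Entity, Column, PrimaryGeneratedColumn, ManyToOne } from 'typeorm';
import { User } from './user.entity';

@Entity('posts')
export class Post {
  @PrimaryGeneratedColumn()
  id: number;

  @Column()
  title: string;

  // Many posts belong to one user
  @ManyToOne(() => User, (user) => user.posts)
  author: User;
}

// The inverse side on the User entity would look like:
// @OneToMany(() => Post, (post) => post.author)
// posts: Post[];
```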
Service with Repository:
import { Injectable } from '@nestjs/common';
import { InjectRepository } from '@nestjs/typeorm';
import { Repository } from 'typeorm';
import { User } from './entities/user.entity';
import { CreateUserDto } from './dto/create-user.dto';
@Injectable()
export class UsersService {
constructor(
@InjectRepository(User)
private usersRepository: Repository<User>,
) {}
async findAll(): Promise<User[]> {
return this.usersRepository.find();
}
async findOne(id: number): Promise<User> {
return this.usersRepository.findOne({ where: { id } });
}
async create(createUserDto: CreateUserDto): Promise<User> {
const user = this.usersRepository.create(createUserDto);
return this.usersRepository.save(user);
}
async update(id: number, updateData: Partial<User>): Promise<User> {
await this.usersRepository.update(id, updateData);
return this.findOne(id);
}
async remove(id: number): Promise<void> {
await this.usersRepository.delete(id);
}
}
Prisma Integration
Installation:
npm install @prisma/client
npm install -D prisma
npx prisma init
Prisma Service:
import { Injectable, OnModuleInit, OnModuleDestroy } from '@nestjs/common';
import { PrismaClient } from '@prisma/client';
@Injectable()
export class PrismaService extends PrismaClient implements OnModuleInit, OnModuleDestroy {
async onModuleInit() {
await this.$connect();
}
async onModuleDestroy() {
await this.$disconnect();
}
}
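With PrismaService registered as a provider, other services inject it directly. A sketch assuming a `model User` exists in your schema.prisma:

```typescript
import { Injectable } from '@nestjs/common';
import { PrismaService } from './prisma.service';

@Injectable()
export class UsersService {
  constructor(private prisma: PrismaService) {}

  // this.prisma.user is generated from `model User` in schema.prisma
  findAll() {
    return this.prisma.user.findMany();
  }

  findOne(id: number) {
    return this.prisma.user.findUnique({ where: { id } });
  }
}
```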
Authentication and Authorization
JWT Authentication
Installation:
npm install @nestjs/jwt @nestjs/passport passport passport-jwt
npm install -D @types/passport-jwt
Auth Module:
import { Module } from '@nestjs/common';
import { JwtModule } from '@nestjs/jwt';
import { PassportModule } from '@nestjs/passport';
import { AuthService } from './auth.service';
import { AuthController } from './auth.controller';
import { JwtStrategy } from './strategies/jwt.strategy';
import { UsersModule } from '../users/users.module';
@Module({
imports: [
UsersModule,
PassportModule,
JwtModule.register({
secret: process.env.JWT_SECRET,
signOptions: { expiresIn: '1d' },
}),
],
controllers: [AuthController],
providers: [AuthService, JwtStrategy],
exports: [AuthService],
})
export class AuthModule {}
Auth Service:
import { Injectable, UnauthorizedException } from '@nestjs/common';
import { JwtService } from '@nestjs/jwt';
import { UsersService } from '../users/users.service';
import * as bcrypt from 'bcrypt';
@Injectable()
export class AuthService {
constructor(
private usersService: UsersService,
private jwtService: JwtService,
) {}
async signIn(email: string, password: string) {
const user = await this.usersService.findByEmail(email);
if (!user) {
throw new UnauthorizedException('Invalid credentials');
}
const isPasswordValid = await bcrypt.compare(password, user.password);
if (!isPasswordValid) {
throw new UnauthorizedException('Invalid credentials');
}
const payload = { sub: user.id, email: user.email };
return {
access_token: await this.jwtService.signAsync(payload),
user: {
id: user.id,
email: user.email,
name: user.name,
},
};
}
async signUp(email: string, password: string, name: string) {
const hashedPassword = await bcrypt.hash(password, 10);
const user = await this.usersService.create({
email,
password: hashedPassword,
name,
});
const payload = { sub: user.id, email: user.email };
return {
access_token: await this.jwtService.signAsync(payload),
user: {
id: user.id,
email: user.email,
name: user.name,
},
};
}
}
JWT Strategy:
import { Injectable } from '@nestjs/common';
import { PassportStrategy } from '@nestjs/passport';
import { ExtractJwt, Strategy } from 'passport-jwt';
@Injectable()
export class JwtStrategy extends PassportStrategy(Strategy) {
constructor() {
super({
jwtFromRequest: ExtractJwt.fromAuthHeaderAsBearerToken(),
ignoreExpiration: false,
secretOrKey: process.env.JWT_SECRET,
});
}
async validate(payload: any) {
return { userId: payload.sub, email: payload.email };
}
}
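The AuthController referenced in the Auth Module above is not shown; a minimal sketch wiring it to the AuthService (the inline body shapes stand in for proper DTO classes):

```typescript
import { Body, Controller, Post, HttpCode, HttpStatus } from '@nestjs/common';
import { AuthService } from './auth.service';

@Controller('auth')
export class AuthController {
  constructor(private authService: AuthService) {}

  @Post('login')
  @HttpCode(HttpStatus.OK) // return 200 instead of the default 201 for POST
  signIn(@Body() body: { email: string; password: string }) {
    return this.authService.signIn(body.email, body.password);
  }

  @Post('register')
  signUp(@Body() body: { email: string; password: string; name: string }) {
    return this.authService.signUp(body.email, body.password, body.name);
  }
}
```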
GraphQL
Installation:
npm install @nestjs/graphql @nestjs/apollo @apollo/server graphql
Configuration:
import { Module } from '@nestjs/common';
import { GraphQLModule } from '@nestjs/graphql';
import { ApolloDriver, ApolloDriverConfig } from '@nestjs/apollo';
@Module({
imports: [
GraphQLModule.forRoot<ApolloDriverConfig>({
driver: ApolloDriver,
autoSchemaFile: true,
playground: true,
}),
],
})
export class AppModule {}
Resolver:
import { Resolver, Query, Mutation, Args, Int } from '@nestjs/graphql';
import { User } from './models/user.model';
import { UsersService } from './users.service';
import { CreateUserInput } from './dto/create-user.input';
@Resolver(() => User)
export class UsersResolver {
constructor(private usersService: UsersService) {}
@Query(() => [User], { name: 'users' })
findAll() {
return this.usersService.findAll();
}
@Query(() => User, { name: 'user' })
findOne(@Args('id', { type: () => Int }) id: number) {
return this.usersService.findOne(id);
}
@Mutation(() => User)
createUser(@Args('createUserInput') createUserInput: CreateUserInput) {
return this.usersService.create(createUserInput);
}
}
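The User model imported by the resolver is a GraphQL object type, distinct from the database entity. A sketch with field names assumed to match the earlier REST examples:

```typescript
import { Field, Int, ObjectType } from '@nestjs/graphql';

@ObjectType()
export class User {
  @Field(() => Int)
  id: number;

  @Field()
  name: string;

  @Field()
  email: string;

  @Field({ nullable: true })
  bio?: string;
}
```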
Microservices
TCP Microservice
Server:
import { NestFactory } from '@nestjs/core';
import { Transport, MicroserviceOptions } from '@nestjs/microservices';
import { AppModule } from './app.module';
async function bootstrap() {
const app = await NestFactory.createMicroservice<MicroserviceOptions>(AppModule, {
transport: Transport.TCP,
options: {
host: '127.0.0.1',
port: 8877,
},
});
await app.listen();
}
bootstrap();
Controller:
import { Controller } from '@nestjs/common';
import { MessagePattern, Payload } from '@nestjs/microservices';
@Controller()
export class MathController {
@MessagePattern({ cmd: 'sum' })
accumulate(@Payload() data: number[]): number {
return (data || []).reduce((a, b) => a + b, 0);
}
}
Client:
import { Injectable } from '@nestjs/common';
import { ClientProxy, ClientProxyFactory, Transport } from '@nestjs/microservices';
import { firstValueFrom } from 'rxjs';
@Injectable()
export class AppService {
private client: ClientProxy;
constructor() {
this.client = ClientProxyFactory.create({
transport: Transport.TCP,
options: {
host: '127.0.0.1',
port: 8877,
},
});
}
async accumulate(): Promise<number> {
const pattern = { cmd: 'sum' };
const payload = [1, 2, 3];
// client.send() returns an Observable; convert it to a Promise
return firstValueFrom(this.client.send<number>(pattern, payload));
}
}
Testing
Unit Testing
import { Test, TestingModule } from '@nestjs/testing';
import { UsersService } from './users.service';
import { getRepositoryToken } from '@nestjs/typeorm';
import { User } from './entities/user.entity';
describe('UsersService', () => {
let service: UsersService;
const mockUserRepository = {
find: jest.fn(),
findOne: jest.fn(),
create: jest.fn(),
save: jest.fn(),
update: jest.fn(),
delete: jest.fn(),
};
beforeEach(async () => {
const module: TestingModule = await Test.createTestingModule({
providers: [
UsersService,
{
provide: getRepositoryToken(User),
useValue: mockUserRepository,
},
],
}).compile();
service = module.get<UsersService>(UsersService);
});
it('should be defined', () => {
expect(service).toBeDefined();
});
describe('findAll', () => {
it('should return an array of users', async () => {
const users = [{ id: 1, name: 'John' }];
mockUserRepository.find.mockResolvedValue(users);
const result = await service.findAll();
expect(result).toEqual(users);
expect(mockUserRepository.find).toHaveBeenCalled();
});
});
});
E2E Testing
import { Test, TestingModule } from '@nestjs/testing';
import { INestApplication } from '@nestjs/common';
import * as request from 'supertest';
import { AppModule } from './../src/app.module';
describe('UsersController (e2e)', () => {
let app: INestApplication;
beforeAll(async () => {
const moduleFixture: TestingModule = await Test.createTestingModule({
imports: [AppModule],
}).compile();
app = moduleFixture.createNestApplication();
await app.init();
});
afterAll(async () => {
await app.close();
});
it('/users (GET)', () => {
return request(app.getHttpServer())
.get('/users')
.expect(200)
.expect('Content-Type', /json/);
});
it('/users (POST)', () => {
return request(app.getHttpServer())
.post('/users')
.send({
name: 'John Doe',
email: 'john@example.com',
})
.expect(201)
.then(response => {
expect(response.body).toHaveProperty('id');
expect(response.body.name).toBe('John Doe');
});
});
});
Best Practices
1. Module Organization
// Feature-based organization
src/
├── modules/
│ ├── users/
│ │ ├── dto/
│ │ ├── entities/
│ │ ├── users.controller.ts
│ │ ├── users.service.ts
│ │ ├── users.module.ts
│ │ └── users.controller.spec.ts
│ └── products/
└── common/
├── guards/
├── interceptors/
├── pipes/
└── decorators/
2. DTOs and Validation
import { IsString, IsEmail, IsOptional, MinLength } from 'class-validator';
import { ApiProperty, ApiPropertyOptional } from '@nestjs/swagger';
export class CreateUserDto {
@ApiProperty({ example: 'John Doe' })
@IsString()
@MinLength(3)
name: string;
@ApiProperty({ example: 'john@example.com' })
@IsEmail()
email: string;
@ApiPropertyOptional()
@IsString()
@IsOptional()
bio?: string;
}
3. Environment Configuration
import { ConfigModule, ConfigService } from '@nestjs/config';
import * as Joi from 'joi';
@Module({
imports: [
ConfigModule.forRoot({
validationSchema: Joi.object({
NODE_ENV: Joi.string()
.valid('development', 'production', 'test')
.default('development'),
PORT: Joi.number().default(3000),
DATABASE_URL: Joi.string().required(),
JWT_SECRET: Joi.string().required(),
}),
}),
],
})
export class AppModule {}
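Once validated, configuration values should be read through ConfigService rather than process.env directly, so defaults and types stay in one place. For example:

```typescript
import { Injectable } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';

@Injectable()
export class AppService {
  constructor(private configService: ConfigService) {}

  getDatabaseUrl(): string {
    // The second argument is a fallback default if the key is unset
    return this.configService.get<string>('DATABASE_URL', '');
  }
}
```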
4. Error Handling
import { HttpException, HttpStatus } from '@nestjs/common';
export class UserNotFoundException extends HttpException {
constructor(userId: number) {
super(`User with ID ${userId} not found`, HttpStatus.NOT_FOUND);
}
}
// Usage
throw new UserNotFoundException(id);
5. Logging
import { Logger, Injectable } from '@nestjs/common';
@Injectable()
export class UsersService {
private readonly logger = new Logger(UsersService.name);
// usersRepository is assumed injected via the constructor (omitted for brevity)
async findAll() {
this.logger.log('Fetching all users');
try {
const users = await this.usersRepository.find();
this.logger.log(`Found ${users.length} users`);
return users;
} catch (error) {
this.logger.error('Failed to fetch users', error.stack);
throw error;
}
}
}
Production Deployment
Environment Variables
.env.production:
NODE_ENV=production
PORT=3000
DATABASE_URL=postgresql://user:password@localhost:5432/mydb
JWT_SECRET=your-secret-key
REDIS_URL=redis://localhost:6379
Docker Deployment
Dockerfile:
FROM node:18-alpine AS builder
WORKDIR /app
COPY package*.json ./
RUN npm ci
COPY . .
RUN npm run build
FROM node:18-alpine
WORKDIR /app
COPY package*.json ./
RUN npm ci --only=production
COPY --from=builder /app/dist ./dist
EXPOSE 3000
CMD ["node", "dist/main"]
docker-compose.yml:
version: '3.8'
services:
app:
build: .
ports:
- "3000:3000"
environment:
- NODE_ENV=production
- DATABASE_URL=postgresql://postgres:password@db:5432/mydb
depends_on:
- db
- redis
db:
image: postgres:15-alpine
environment:
POSTGRES_PASSWORD: password
POSTGRES_DB: mydb
volumes:
- postgres_data:/var/lib/postgresql/data
redis:
image: redis:7-alpine
ports:
- "6379:6379"
volumes:
postgres_data:
PM2 Deployment
ecosystem.config.js:
module.exports = {
apps: [{
name: 'nest-app',
script: 'dist/main.js',
instances: 'max',
exec_mode: 'cluster',
env: {
NODE_ENV: 'production',
},
}],
};
Health Checks
import { Controller, Get } from '@nestjs/common';
import {
HealthCheckService,
HealthCheck,
TypeOrmHealthIndicator,
MemoryHealthIndicator
} from '@nestjs/terminus';
@Controller('health')
export class HealthController {
constructor(
private health: HealthCheckService,
private db: TypeOrmHealthIndicator,
private memory: MemoryHealthIndicator,
) {}
@Get()
@HealthCheck()
check() {
return this.health.check([
() => this.db.pingCheck('database'),
() => this.memory.checkHeap('memory_heap', 150 * 1024 * 1024),
]);
}
}
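The health indicators above come from @nestjs/terminus, which must be installed and wired into a module before the controller will work. A sketch:

```typescript
// npm install @nestjs/terminus
import { Module } from '@nestjs/common';
import { TerminusModule } from '@nestjs/terminus';
import { HealthController } from './health.controller';

@Module({
  imports: [TerminusModule],
  controllers: [HealthController],
})
export class HealthModule {}
```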
Performance Optimization
// Enable compression
import * as compression from 'compression';
app.use(compression());
// Enable helmet for security
import helmet from 'helmet';
app.use(helmet());
// Rate limiting
import { ThrottlerModule } from '@nestjs/throttler';
@Module({
imports: [
ThrottlerModule.forRoot({
ttl: 60,
limit: 10,
}),
],
})
// Caching
import { CacheModule } from '@nestjs/cache-manager';
@Module({
imports: [
CacheModule.register({
ttl: 5,
max: 100,
}),
],
})
Resources
- Official Documentation: https://docs.nestjs.com/
- GitHub Repository: https://github.com/nestjs/nest
Django
Django is a high-level Python web framework that encourages rapid development and clean, pragmatic design. Built by experienced developers, it takes care of much of the hassle of web development, so you can focus on writing your app without needing to reinvent the wheel. It follows the "batteries-included" philosophy and provides a complete solution for web development.
Table of Contents
- Introduction
- Installation and Setup
- Project Structure
- Models and Database
- Views
- URL Routing
- Templates
- Forms
- Authentication
- Django REST Framework
- Admin Interface
- Middleware
- Static and Media Files
- Testing
- Best Practices
- Production Deployment
Introduction
Key Features:
- Object-Relational Mapper (ORM) for database operations
- Automatic admin interface
- Clean, pragmatic URL design
- Template engine for dynamic HTML
- Built-in authentication and authorization
- Form handling and validation
- Security features (CSRF, XSS, SQL injection protection)
- Scalable architecture
- Excellent documentation
- Large ecosystem of packages
Use Cases:
- Content Management Systems (CMS)
- E-commerce platforms
- Social networks
- Data-driven web applications
- RESTful APIs
- Real-time applications
- Scientific computing platforms
- Financial applications
Philosophy:
- Don't Repeat Yourself (DRY)
- Explicit is better than implicit
- Loose coupling and tight cohesion
- Convention over configuration
Installation and Setup
Prerequisites
# Python 3.8+ required
python3 --version
pip --version
Virtual Environment Setup
# Create virtual environment
python3 -m venv venv
# Activate virtual environment
# On Linux/Mac:
source venv/bin/activate
# On Windows:
# venv\Scripts\activate
# Upgrade pip
pip install --upgrade pip
Install Django
# Install Django
pip install django
# Verify installation
django-admin --version
# Install additional packages
pip install python-decouple psycopg2-binary pillow django-cors-headers
Create New Project
# Create Django project
django-admin startproject myproject
# Navigate to project
cd myproject
# Create an app
python manage.py startapp myapp
# Run development server
python manage.py runserver
# Server runs on http://127.0.0.1:8000/
Initial Database Setup
# Create initial migrations
python manage.py makemigrations
# Apply migrations
python manage.py migrate
# Create superuser
python manage.py createsuperuser
Project Structure
myproject/
├── manage.py # Command-line utility
├── myproject/ # Project package
│ ├── __init__.py
│ ├── settings.py # Project settings
│ ├── urls.py # URL declarations
│ ├── asgi.py # ASGI entry point
│ └── wsgi.py # WSGI entry point
├── myapp/ # Application package
│ ├── migrations/ # Database migrations
│ ├── __init__.py
│ ├── admin.py # Admin configuration
│ ├── apps.py # App configuration
│ ├── models.py # Data models
│ ├── tests.py # Tests
│ ├── views.py # View functions/classes
│ └── urls.py # App URL patterns
├── templates/ # HTML templates
├── static/ # Static files (CSS, JS, images)
├── media/ # User-uploaded files
└── requirements.txt # Project dependencies
Settings Configuration
settings.py:
import os
from pathlib import Path
from decouple import config
BASE_DIR = Path(__file__).resolve().parent.parent
SECRET_KEY = config('SECRET_KEY', default='your-secret-key-here')
DEBUG = config('DEBUG', default=False, cast=bool)
ALLOWED_HOSTS = config('ALLOWED_HOSTS', default='localhost,127.0.0.1').split(',')
INSTALLED_APPS = [
'django.contrib.admin',
'django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.sessions',
'django.contrib.messages',
'django.contrib.staticfiles',
'myapp', # Your app
'rest_framework', # For APIs
'corsheaders', # CORS headers
]
MIDDLEWARE = [
'django.middleware.security.SecurityMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware',
'corsheaders.middleware.CorsMiddleware',
'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
]
ROOT_URLCONF = 'myproject.urls'
TEMPLATES = [
{
'BACKEND': 'django.template.backends.django.DjangoTemplates',
'DIRS': [BASE_DIR / 'templates'],
'APP_DIRS': True,
'OPTIONS': {
'context_processors': [
'django.template.context_processors.debug',
'django.template.context_processors.request',
'django.contrib.auth.context_processors.auth',
'django.contrib.messages.context_processors.messages',
],
},
},
]
DATABASES = {
'default': {
'ENGINE': 'django.db.backends.postgresql',
'NAME': config('DB_NAME', default='mydb'),
'USER': config('DB_USER', default='postgres'),
'PASSWORD': config('DB_PASSWORD', default='password'),
'HOST': config('DB_HOST', default='localhost'),
'PORT': config('DB_PORT', default='5432'),
}
}
STATIC_URL = '/static/'
STATIC_ROOT = BASE_DIR / 'staticfiles'
STATICFILES_DIRS = [BASE_DIR / 'static']
MEDIA_URL = '/media/'
MEDIA_ROOT = BASE_DIR / 'media'
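The config() calls above read values from a .env file at the project root (that is how python-decouple works). A matching example file with placeholder values — never commit the real one:

```
SECRET_KEY=replace-with-a-long-random-string
DEBUG=True
ALLOWED_HOSTS=localhost,127.0.0.1
DB_NAME=mydb
DB_USER=postgres
DB_PASSWORD=password
DB_HOST=localhost
DB_PORT=5432
```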
Models and Database
Basic Model
from django.db import models
from django.contrib.auth.models import User
class Category(models.Model):
name = models.CharField(max_length=100)
slug = models.SlugField(unique=True)
description = models.TextField(blank=True)
created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True)
class Meta:
verbose_name_plural = "Categories"
ordering = ['name']
def __str__(self):
return self.name
class Product(models.Model):
STATUS_CHOICES = [
('draft', 'Draft'),
('published', 'Published'),
('archived', 'Archived'),
]
name = models.CharField(max_length=200)
slug = models.SlugField(unique=True)
description = models.TextField()
price = models.DecimalField(max_digits=10, decimal_places=2)
category = models.ForeignKey(Category, on_delete=models.CASCADE, related_name='products')
image = models.ImageField(upload_to='products/', blank=True, null=True)
status = models.CharField(max_length=20, choices=STATUS_CHOICES, default='draft')
stock = models.IntegerField(default=0)
created_by = models.ForeignKey(User, on_delete=models.SET_NULL, null=True)
created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True)
class Meta:
ordering = ['-created_at']
indexes = [
models.Index(fields=['slug']),
models.Index(fields=['status', 'created_at']),
]
def __str__(self):
return self.name
@property
def is_available(self):
return self.stock > 0 and self.status == 'published'
Advanced Models
from django.db import models
from django.contrib.auth.models import User
from django.core.validators import MinValueValidator, MaxValueValidator
class TimestampedModel(models.Model):
"""Abstract base model with timestamp fields"""
created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True)
class Meta:
abstract = True
class Review(TimestampedModel):
product = models.ForeignKey('Product', on_delete=models.CASCADE, related_name='reviews')
user = models.ForeignKey(User, on_delete=models.CASCADE)
rating = models.IntegerField(
validators=[MinValueValidator(1), MaxValueValidator(5)]
)
title = models.CharField(max_length=200)
comment = models.TextField()
helpful_count = models.IntegerField(default=0)
class Meta:
unique_together = ['product', 'user']
ordering = ['-created_at']
def __str__(self):
return f"{self.user.username} - {self.product.name} ({self.rating}★)"
class Order(TimestampedModel):
ORDER_STATUS = [
('pending', 'Pending'),
('processing', 'Processing'),
('shipped', 'Shipped'),
('delivered', 'Delivered'),
('cancelled', 'Cancelled'),
]
user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='orders')
status = models.CharField(max_length=20, choices=ORDER_STATUS, default='pending')
total_amount = models.DecimalField(max_digits=10, decimal_places=2)
shipping_address = models.TextField()
tracking_number = models.CharField(max_length=100, blank=True)
def __str__(self):
return f"Order #{self.id} - {self.user.username}"
class OrderItem(models.Model):
order = models.ForeignKey(Order, on_delete=models.CASCADE, related_name='items')
product = models.ForeignKey(Product, on_delete=models.PROTECT)
quantity = models.IntegerField(validators=[MinValueValidator(1)])
price = models.DecimalField(max_digits=10, decimal_places=2)
def __str__(self):
return f"{self.quantity}x {self.product.name}"
@property
def subtotal(self):
return self.quantity * self.price
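Because OrderItem.price is a DecimalField, line totals should be computed with Decimal rather than float. A plain-Python sketch of the arithmetic behind subtotal and an order total (no ORM involved; the item tuples are illustrative):

```python
from decimal import Decimal

# (quantity, unit price) pairs standing in for OrderItem rows
items = [(2, Decimal("19.99")), (1, Decimal("999.00"))]

# mirrors OrderItem.subtotal: quantity * price
subtotals = [qty * price for qty, price in items]

# mirrors what Order.total_amount should hold
total = sum(subtotals, Decimal("0"))
print(total)  # 1038.98
```

Using Decimal throughout avoids the rounding surprises that float arithmetic introduces with currency values.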
QuerySet Operations
from django.db.models import Q, Count, Avg, Sum
# Basic queries
products = Product.objects.all()
product = Product.objects.get(id=1)
products = Product.objects.filter(status='published')
products = Product.objects.exclude(stock=0)
# Complex queries
products = Product.objects.filter(
Q(name__icontains='laptop') | Q(description__icontains='laptop'),
price__gte=500,
status='published'
).select_related('category').prefetch_related('reviews')
# Aggregation
stats = Product.objects.aggregate(
total_products=Count('id'),
avg_price=Avg('price'),
total_stock=Sum('stock')
)
# Annotation
categories = Category.objects.annotate(
product_count=Count('products'),
avg_price=Avg('products__price')
).filter(product_count__gt=0)
# Custom managers
class PublishedManager(models.Manager):
def get_queryset(self):
return super().get_queryset().filter(status='published')
class Product(models.Model):
# ... fields ...
objects = models.Manager()
published = PublishedManager()
# Usage
published_products = Product.published.all()
Migrations
# Create migrations
python manage.py makemigrations
# Apply migrations
python manage.py migrate
# Show migrations
python manage.py showmigrations
# Revert migration
python manage.py migrate myapp 0001
# Create empty migration
python manage.py makemigrations --empty myapp
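Empty migrations are typically used for data migrations. A sketch of what goes inside one (the app label, dependency name, and backfill logic here are all hypothetical):

```python
from django.db import migrations

def backfill_status(apps, schema_editor):
    # Use the historical model via apps.get_model, not a direct import
    Product = apps.get_model('myapp', 'Product')
    Product.objects.filter(status='').update(status='draft')

class Migration(migrations.Migration):
    dependencies = [
        ('myapp', '0001_initial'),  # hypothetical previous migration
    ]
    operations = [
        migrations.RunPython(backfill_status, migrations.RunPython.noop),
    ]
```

Passing RunPython.noop as the reverse function keeps the migration reversible even though the backfill itself is not undone.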
Views
Function-Based Views
from django.shortcuts import render, get_object_or_404, redirect
from django.http import HttpResponse, JsonResponse
from django.contrib.auth.decorators import login_required
from .models import Product, Category
from .forms import ProductForm
def product_list(request):
products = Product.objects.filter(status='published')
categories = Category.objects.all()
context = {
'products': products,
'categories': categories,
}
return render(request, 'products/list.html', context)
def product_detail(request, slug):
product = get_object_or_404(Product, slug=slug, status='published')
related_products = Product.objects.filter(
category=product.category,
status='published'
).exclude(id=product.id)[:4]
context = {
'product': product,
'related_products': related_products,
}
return render(request, 'products/detail.html', context)
@login_required
def product_create(request):
if request.method == 'POST':
form = ProductForm(request.POST, request.FILES)
if form.is_valid():
product = form.save(commit=False)
product.created_by = request.user
product.save()
return redirect('products:detail', slug=product.slug)
else:
form = ProductForm()
return render(request, 'products/form.html', {'form': form})
def api_products(request):
products = Product.objects.filter(status='published').values(
'id', 'name', 'price', 'slug'
)
return JsonResponse(list(products), safe=False)
Class-Based Views
from django.views.generic import ListView, DetailView, CreateView, UpdateView, DeleteView
from django.contrib.auth.mixins import LoginRequiredMixin, UserPassesTestMixin
from django.urls import reverse_lazy
from django.db.models import Q
from .models import Product, Category
from .forms import ProductForm
class ProductListView(ListView):
model = Product
template_name = 'products/list.html'
context_object_name = 'products'
paginate_by = 12
def get_queryset(self):
queryset = Product.objects.filter(status='published')
# Filter by category
category_slug = self.request.GET.get('category')
if category_slug:
queryset = queryset.filter(category__slug=category_slug)
# Search
search_query = self.request.GET.get('q')
if search_query:
queryset = queryset.filter(
Q(name__icontains=search_query) |
Q(description__icontains=search_query)
)
return queryset.select_related('category')
def get_context_data(self, **kwargs):
context = super().get_context_data(**kwargs)
context['categories'] = Category.objects.all()
return context
class ProductDetailView(DetailView):
model = Product
template_name = 'products/detail.html'
context_object_name = 'product'
def get_queryset(self):
return Product.objects.filter(status='published').select_related('category')
class ProductCreateView(LoginRequiredMixin, CreateView):
model = Product
form_class = ProductForm
template_name = 'products/form.html'
success_url = reverse_lazy('products:list')
def form_valid(self, form):
form.instance.created_by = self.request.user
return super().form_valid(form)
class ProductUpdateView(LoginRequiredMixin, UserPassesTestMixin, UpdateView):
model = Product
form_class = ProductForm
template_name = 'products/form.html'
def test_func(self):
product = self.get_object()
return self.request.user == product.created_by or self.request.user.is_staff
def get_success_url(self):
return reverse_lazy('products:detail', kwargs={'slug': self.object.slug})
class ProductDeleteView(LoginRequiredMixin, UserPassesTestMixin, DeleteView):
model = Product
success_url = reverse_lazy('products:list')
def test_func(self):
product = self.get_object()
return self.request.user == product.created_by or self.request.user.is_staff
URL Routing
Project URLs
myproject/urls.py:
from django.contrib import admin
from django.urls import path, include
from django.conf import settings
from django.conf.urls.static import static
urlpatterns = [
path('admin/', admin.site.urls),
path('', include('myapp.urls')),
path('api/', include('myapp.api.urls')),
path('accounts/', include('django.contrib.auth.urls')),
]
if settings.DEBUG:
urlpatterns += static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT)
urlpatterns += static(settings.STATIC_URL, document_root=settings.STATIC_ROOT)
App URLs
myapp/urls.py:
from django.urls import path
from . import views
app_name = 'products'
urlpatterns = [
path('', views.ProductListView.as_view(), name='list'),
path('create/', views.ProductCreateView.as_view(), name='create'),
path('<slug:slug>/', views.ProductDetailView.as_view(), name='detail'),
path('<slug:slug>/edit/', views.ProductUpdateView.as_view(), name='edit'),
path('<slug:slug>/delete/', views.ProductDeleteView.as_view(), name='delete'),
# API endpoints
path('api/products/', views.api_products, name='api_list'),
]
URL Parameters
from django.urls import path, re_path
from . import views
urlpatterns = [
# String parameter
path('products/<slug:slug>/', views.product_detail),
# Integer parameter
path('products/<int:id>/', views.product_by_id),
# UUID parameter
path('orders/<uuid:order_id>/', views.order_detail),
# Regular expression
re_path(r'^articles/(?P<year>[0-9]{4})/$', views.year_archive),
]
Templates
Base Template
templates/base.html:
{% load static %}
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>{% block title %}My Site{% endblock %}</title>
<link rel="stylesheet" href="{% static 'css/style.css' %}">
{% block extra_css %}{% endblock %}
</head>
<body>
<nav>
<a href="{% url 'products:list' %}">Products</a>
{% if user.is_authenticated %}
<a href="{% url 'products:create' %}">Add Product</a>
<span>Hello, {{ user.username }}!</span>
<a href="{% url 'logout' %}">Logout</a>
{% else %}
<a href="{% url 'login' %}">Login</a>
{% endif %}
</nav>
<main>
{% if messages %}
{% for message in messages %}
<div class="alert alert-{{ message.tags }}">
{{ message }}
</div>
{% endfor %}
{% endif %}
{% block content %}{% endblock %}
</main>
<footer>
<p>© 2024 My Site</p>
</footer>
<script src="{% static 'js/main.js' %}"></script>
{% block extra_js %}{% endblock %}
</body>
</html>
List Template
templates/products/list.html:
{% extends 'base.html' %}
{% load static %}
{% block title %}Products{% endblock %}
{% block content %}
<div class="products-container">
<h1>Products</h1>
<form method="get" class="search-form">
<input type="text" name="q" placeholder="Search products..." value="{{ request.GET.q }}">
<select name="category">
<option value="">All Categories</option>
{% for category in categories %}
<option value="{{ category.slug }}"
{% if request.GET.category == category.slug %}selected{% endif %}>
{{ category.name }}
</option>
{% endfor %}
</select>
<button type="submit">Search</button>
</form>
<div class="products-grid">
{% for product in products %}
<div class="product-card">
{% if product.image %}
<img src="{{ product.image.url }}" alt="{{ product.name }}">
{% else %}
<img src="{% static 'images/placeholder.png' %}" alt="No image">
{% endif %}
<h3>{{ product.name }}</h3>
<p>{{ product.description|truncatewords:20 }}</p>
<p class="price">${{ product.price }}</p>
<a href="{% url 'products:detail' product.slug %}" class="btn">View Details</a>
</div>
{% empty %}
<p>No products found.</p>
{% endfor %}
</div>
{% if is_paginated %}
<div class="pagination">
{% if page_obj.has_previous %}
<a href="?page=1">« First</a>
<a href="?page={{ page_obj.previous_page_number }}">Previous</a>
{% endif %}
<span class="current-page">
Page {{ page_obj.number }} of {{ page_obj.paginator.num_pages }}
</span>
{% if page_obj.has_next %}
<a href="?page={{ page_obj.next_page_number }}">Next</a>
<a href="?page={{ page_obj.paginator.num_pages }}">Last »</a>
{% endif %}
</div>
{% endif %}
</div>
{% endblock %}
Custom Template Tags
myapp/templatetags/custom_tags.py:
from django import template
from django.utils.html import format_html
from django.utils.safestring import mark_safe
register = template.Library()
@register.filter
def currency(value):
"""Format number as currency"""
return f"${value:,.2f}"
@register.simple_tag
def star_rating(rating):
"""Display star rating"""
full_stars = int(rating)
half_star = 1 if rating - full_stars >= 0.5 else 0
empty_stars = 5 - full_stars - half_star
stars = '★' * full_stars + '½' * half_star + '☆' * empty_stars
return format_html('<span class="rating">{}</span>', stars)
@register.inclusion_tag('includes/product_card.html')
def product_card(product):
"""Render product card"""
return {'product': product}
Forms
Model Form
from django import forms
from django.core.exceptions import ValidationError
from .models import Product, Review, Category
class ProductForm(forms.ModelForm):
class Meta:
model = Product
fields = ['name', 'slug', 'description', 'price', 'category', 'image', 'stock', 'status']
widgets = {
'description': forms.Textarea(attrs={'rows': 4}),
'price': forms.NumberInput(attrs={'step': '0.01'}),
}
def clean_price(self):
price = self.cleaned_data.get('price')
if price and price < 0:
raise ValidationError('Price cannot be negative')
return price
def clean_name(self):
name = self.cleaned_data.get('name')
if Product.objects.filter(name=name).exclude(pk=self.instance.pk).exists():
raise ValidationError('Product with this name already exists')
return name
class ReviewForm(forms.ModelForm):
class Meta:
model = Review
fields = ['rating', 'title', 'comment']
widgets = {
'rating': forms.RadioSelect(choices=[(i, f'{i}★') for i in range(1, 6)]),
'comment': forms.Textarea(attrs={'rows': 4, 'placeholder': 'Share your experience...'}),
}
class SearchForm(forms.Form):
query = forms.CharField(
max_length=100,
required=False,
widget=forms.TextInput(attrs={'placeholder': 'Search products...'})
)
category = forms.ModelChoiceField(
queryset=Category.objects.all(),
required=False,
empty_label='All Categories'
)
min_price = forms.DecimalField(required=False, min_value=0)
max_price = forms.DecimalField(required=False, min_value=0)
def clean(self):
cleaned_data = super().clean()
min_price = cleaned_data.get('min_price')
max_price = cleaned_data.get('max_price')
if min_price and max_price and min_price > max_price:
raise ValidationError('Minimum price cannot be greater than maximum price')
return cleaned_data
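The validated SearchForm data maps naturally onto ORM lookup keyword arguments (the __icontains/__gte/__lte suffixes are standard Django field lookups). A plain-dict sketch of that mapping, runnable without the ORM — in a view you would then call Product.objects.filter(**kwargs):

```python
def build_filter_kwargs(cleaned_data):
    """Translate SearchForm.cleaned_data into filter() keyword arguments."""
    kwargs = {}
    if cleaned_data.get("query"):
        kwargs["name__icontains"] = cleaned_data["query"]
    if cleaned_data.get("category"):
        kwargs["category"] = cleaned_data["category"]
    if cleaned_data.get("min_price") is not None:
        kwargs["price__gte"] = cleaned_data["min_price"]
    if cleaned_data.get("max_price") is not None:
        kwargs["price__lte"] = cleaned_data["max_price"]
    return kwargs

print(build_filter_kwargs({"query": "laptop", "min_price": 500}))
# {'name__icontains': 'laptop', 'price__gte': 500}
```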
Custom Validation
from django import forms
from django.core.validators import EmailValidator, RegexValidator
class ContactForm(forms.Form):
name = forms.CharField(
max_length=100,
validators=[
RegexValidator(
regex=r'^[a-zA-Z\s]+$',
message='Name can only contain letters and spaces'
)
]
)
email = forms.EmailField(validators=[EmailValidator()])
phone = forms.CharField(
validators=[
RegexValidator(
regex=r'^\+?1?\d{9,15}$',
message='Enter a valid phone number'
)
]
)
message = forms.CharField(widget=forms.Textarea)
def clean_email(self):
email = self.cleaned_data.get('email')
if email and 'spam' in email.lower():
raise forms.ValidationError('This email appears to be spam')
return email
def send_email(self):
# Send email logic here
pass
Authentication
Login and Logout
from django.contrib.auth import authenticate, login, logout
from django.contrib.auth.forms import UserCreationForm
from django.shortcuts import render, redirect
from django.contrib import messages
def user_login(request):
if request.method == 'POST':
username = request.POST.get('username')
password = request.POST.get('password')
user = authenticate(request, username=username, password=password)
if user is not None:
login(request, user)
messages.success(request, f'Welcome back, {user.username}!')
return redirect('home')
else:
messages.error(request, 'Invalid username or password')
return render(request, 'registration/login.html')
def user_logout(request):
logout(request)
messages.info(request, 'You have been logged out')
return redirect('login')
def user_register(request):
if request.method == 'POST':
form = UserCreationForm(request.POST)
if form.is_valid():
user = form.save()
login(request, user)
messages.success(request, 'Registration successful!')
return redirect('home')
else:
form = UserCreationForm()
return render(request, 'registration/register.html', {'form': form})
Custom User Model
from django.contrib.auth.models import AbstractUser
from django.db import models
class CustomUser(AbstractUser):
email = models.EmailField(unique=True)
bio = models.TextField(blank=True)
avatar = models.ImageField(upload_to='avatars/', blank=True)
birth_date = models.DateField(null=True, blank=True)
phone = models.CharField(max_length=20, blank=True)
def __str__(self):
return self.username
# In settings.py
AUTH_USER_MODEL = 'myapp.CustomUser'
Permissions
from django.contrib.auth.decorators import login_required, permission_required
from django.contrib.auth.mixins import PermissionRequiredMixin
# Function-based view
@login_required
@permission_required('myapp.add_product', raise_exception=True)
def create_product(request):
# View logic
pass
# Class-based view
class ProductCreateView(LoginRequiredMixin, PermissionRequiredMixin, CreateView):
model = Product
permission_required = 'myapp.add_product'
# View logic
# Custom permission
class Product(models.Model):
# ... fields ...
class Meta:
permissions = [
("can_publish", "Can publish products"),
("can_feature", "Can feature products"),
]
# Check permission in code
if request.user.has_perm('myapp.can_publish'):
# User has permission
pass
Django REST Framework
Installation
pip install djangorestframework
Configuration
# settings.py
INSTALLED_APPS = [
# ...
'rest_framework',
]
REST_FRAMEWORK = {
'DEFAULT_AUTHENTICATION_CLASSES': [
'rest_framework.authentication.TokenAuthentication',
'rest_framework.authentication.SessionAuthentication',
],
'DEFAULT_PERMISSION_CLASSES': [
'rest_framework.permissions.IsAuthenticatedOrReadOnly',
],
'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination',
'PAGE_SIZE': 10,
}
Serializers
from rest_framework import serializers
from .models import Product, Category, Review
class CategorySerializer(serializers.ModelSerializer):
product_count = serializers.IntegerField(read_only=True)
class Meta:
model = Category
fields = ['id', 'name', 'slug', 'description', 'product_count']
class ProductSerializer(serializers.ModelSerializer):
category = CategorySerializer(read_only=True)
category_id = serializers.IntegerField(write_only=True)
reviews_count = serializers.SerializerMethodField()
average_rating = serializers.SerializerMethodField()
class Meta:
model = Product
fields = [
'id', 'name', 'slug', 'description', 'price',
'category', 'category_id', 'image', 'status', 'stock',
'reviews_count', 'average_rating', 'created_at'
]
read_only_fields = ['slug', 'created_at']
def get_reviews_count(self, obj):
return obj.reviews.count()
def get_average_rating(self, obj):
reviews = obj.reviews.all()
if reviews:
return sum(r.rating for r in reviews) / len(reviews)
return None
def validate_price(self, value):
if value < 0:
raise serializers.ValidationError('Price cannot be negative')
return value
class ReviewSerializer(serializers.ModelSerializer):
user = serializers.StringRelatedField(read_only=True)
class Meta:
model = Review
fields = ['id', 'user', 'rating', 'title', 'comment', 'created_at']
read_only_fields = ['user', 'created_at']
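Note that get_average_rating above loads every review into Python; for list endpoints it is usually cheaper to let the database do the work. A sketch of the same method using an aggregate instead:

```python
from django.db.models import Avg

def get_average_rating(self, obj):
    # Single AVG(rating) query instead of fetching all review rows
    return obj.reviews.aggregate(avg=Avg('rating'))['avg']
```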
API Views
from rest_framework import viewsets, filters, status
from rest_framework.decorators import action, api_view, permission_classes
from rest_framework.response import Response
from rest_framework.permissions import IsAuthenticated, IsAuthenticatedOrReadOnly
from django_filters.rest_framework import DjangoFilterBackend
from django.db.models import Count
from .models import Product, Category
from .serializers import ProductSerializer, CategorySerializer
class ProductViewSet(viewsets.ModelViewSet):
queryset = Product.objects.all()
serializer_class = ProductSerializer
permission_classes = [IsAuthenticatedOrReadOnly]
filter_backends = [DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter]
filterset_fields = ['category', 'status']
search_fields = ['name', 'description']
ordering_fields = ['price', 'created_at']
lookup_field = 'slug'
def perform_create(self, serializer):
serializer.save(created_by=self.request.user)
@action(detail=True, methods=['post'])
def publish(self, request, slug=None):
product = self.get_object()
product.status = 'published'
product.save()
return Response({'status': 'product published'})
@action(detail=False, methods=['get'])
def featured(self, request):
featured_products = self.queryset.filter(status='published', stock__gt=0)[:10]
serializer = self.get_serializer(featured_products, many=True)
return Response(serializer.data)
class CategoryViewSet(viewsets.ReadOnlyModelViewSet):
queryset = Category.objects.annotate(product_count=Count('products'))
serializer_class = CategorySerializer
lookup_field = 'slug'
# Function-based API view
@api_view(['GET', 'POST'])
@permission_classes([IsAuthenticated])
def product_list_create(request):
if request.method == 'GET':
products = Product.objects.all()
serializer = ProductSerializer(products, many=True)
return Response(serializer.data)
elif request.method == 'POST':
serializer = ProductSerializer(data=request.data)
if serializer.is_valid():
serializer.save(created_by=request.user)
return Response(serializer.data, status=status.HTTP_201_CREATED)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
Admin Interface
Basic Admin Registration
from django.contrib import admin
from .models import Product, Category, Review, Order
admin.site.register(Category)
admin.site.register(Review)
Custom Admin
from django.contrib import admin
from django.utils.html import format_html
from .models import Product, Category, Order, OrderItem
@admin.register(Category)
class CategoryAdmin(admin.ModelAdmin):
list_display = ['name', 'slug', 'product_count', 'created_at']
prepopulated_fields = {'slug': ('name',)}
search_fields = ['name']
def product_count(self, obj):
return obj.products.count()
product_count.short_description = 'Products'
@admin.register(Product)
class ProductAdmin(admin.ModelAdmin):
list_display = ['name', 'category', 'price', 'stock', 'status', 'image_preview', 'created_at']
list_filter = ['status', 'category', 'created_at']
search_fields = ['name', 'description']
prepopulated_fields = {'slug': ('name',)}
list_editable = ['price', 'stock', 'status']
readonly_fields = ['created_at', 'updated_at', 'image_preview']
fieldsets = (
('Basic Information', {
'fields': ('name', 'slug', 'description', 'category')
}),
('Pricing and Inventory', {
'fields': ('price', 'stock', 'status')
}),
('Media', {
'fields': ('image', 'image_preview')
}),
('Metadata', {
'fields': ('created_by', 'created_at', 'updated_at'),
'classes': ('collapse',)
}),
)
def image_preview(self, obj):
if obj.image:
return format_html('<img src="{}" width="100" height="100" />', obj.image.url)
return '-'
image_preview.short_description = 'Preview'
class OrderItemInline(admin.TabularInline):
model = OrderItem
extra = 0
readonly_fields = ['subtotal']
@admin.register(Order)
class OrderAdmin(admin.ModelAdmin):
list_display = ['id', 'user', 'status', 'total_amount', 'created_at']
list_filter = ['status', 'created_at']
search_fields = ['user__username', 'user__email', 'tracking_number']
inlines = [OrderItemInline]
readonly_fields = ['created_at', 'updated_at']
actions = ['mark_as_shipped']
def mark_as_shipped(self, request, queryset):
queryset.update(status='shipped')
mark_as_shipped.short_description = 'Mark selected orders as shipped'
Middleware
import time
import logging
from django.utils.deprecation import MiddlewareMixin
logger = logging.getLogger(__name__)
class RequestLoggingMiddleware(MiddlewareMixin):
def process_request(self, request):
request.start_time = time.time()
def process_response(self, request, response):
if hasattr(request, 'start_time'):
duration = time.time() - request.start_time
logger.info(f'{request.method} {request.path} - {response.status_code} - {duration:.2f}s')
return response
class CustomHeaderMiddleware:
def __init__(self, get_response):
self.get_response = get_response
def __call__(self, request):
response = self.get_response(request)
response['X-Custom-Header'] = 'My Custom Value'
return response
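Neither middleware class takes effect until it is listed in MIDDLEWARE; the dotted paths below assume the classes live in myapp/middleware.py:

```python
# settings.py
MIDDLEWARE = [
    'django.middleware.security.SecurityMiddleware',
    # ... the built-in middleware shown earlier ...
    'myapp.middleware.RequestLoggingMiddleware',
    'myapp.middleware.CustomHeaderMiddleware',
]
```

Order matters: middleware is applied top-down on the request and bottom-up on the response, so the logging middleware measures everything listed after it.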
Static and Media Files
Settings
# settings.py
STATIC_URL = '/static/'
STATIC_ROOT = BASE_DIR / 'staticfiles'
STATICFILES_DIRS = [
BASE_DIR / 'static',
]
MEDIA_URL = '/media/'
MEDIA_ROOT = BASE_DIR / 'media'
# For production
STATICFILES_STORAGE = 'django.contrib.staticfiles.storage.ManifestStaticFilesStorage'
Collect Static Files
python manage.py collectstatic
Testing
Unit Tests
from django.test import TestCase
from django.contrib.auth.models import User
from .models import Product, Category
class ProductModelTest(TestCase):
def setUp(self):
self.user = User.objects.create_user(username='testuser', password='12345')
self.category = Category.objects.create(name='Electronics', slug='electronics')
self.product = Product.objects.create(
name='Laptop',
slug='laptop',
description='A great laptop',
price=999.99,
category=self.category,
stock=10,
created_by=self.user
)
def test_product_creation(self):
self.assertEqual(self.product.name, 'Laptop')
self.assertEqual(self.product.price, 999.99)
def test_product_is_available(self):
self.product.status = 'published'
self.assertTrue(self.product.is_available)
def test_product_str(self):
self.assertEqual(str(self.product), 'Laptop')
class ProductViewTest(TestCase):
def setUp(self):
self.user = User.objects.create_user(username='testuser', password='12345')
self.category = Category.objects.create(name='Electronics', slug='electronics')
self.product = Product.objects.create(
name='Laptop',
slug='laptop',
description='A great laptop',
price=999.99,
category=self.category,
status='published',
stock=10,
created_by=self.user
)
def test_product_list_view(self):
response = self.client.get('/products/')
self.assertEqual(response.status_code, 200)
self.assertContains(response, 'Laptop')
self.assertTemplateUsed(response, 'products/list.html')
def test_product_detail_view(self):
response = self.client.get(f'/products/{self.product.slug}/')
self.assertEqual(response.status_code, 200)
self.assertContains(response, self.product.name)
def test_product_create_requires_login(self):
response = self.client.get('/products/create/')
self.assertEqual(response.status_code, 302)
def test_product_create_authenticated(self):
self.client.login(username='testuser', password='12345')
response = self.client.post('/products/create/', {
'name': 'New Product',
'slug': 'new-product',
'description': 'Description',
'price': 99.99,
'category': self.category.id,
'stock': 5,
'status': 'draft'
})
self.assertEqual(response.status_code, 302)
self.assertTrue(Product.objects.filter(name='New Product').exists())
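The test classes above are discovered and run through manage.py; the coverage commands assume the coverage package is installed:

```
# Run all tests
python manage.py test

# Run tests for a single app, with more verbose output
python manage.py test myapp --verbosity=2

# Measure test coverage (pip install coverage)
coverage run manage.py test
coverage report
```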
Best Practices
1. Settings Organization
# settings/
# ├── __init__.py
# ├── base.py
# ├── development.py
# ├── production.py
# └── testing.py
# base.py - Common settings
# development.py
from .base import *
DEBUG = True
ALLOWED_HOSTS = ['localhost', '127.0.0.1']
# production.py
from .base import *
DEBUG = False
ALLOWED_HOSTS = config('ALLOWED_HOSTS').split(',')
2. Use Environment Variables
from decouple import config
SECRET_KEY = config('SECRET_KEY')
DEBUG = config('DEBUG', default=False, cast=bool)
DATABASE_URL = config('DATABASE_URL')
3. Query Optimization
# Use select_related for foreign keys
products = Product.objects.select_related('category').all()
# Use prefetch_related for many-to-many and reverse foreign keys
products = Product.objects.prefetch_related('reviews').all()
# Only get needed fields
products = Product.objects.values('id', 'name', 'price')
# Use iterator for large querysets
for product in Product.objects.iterator():
# Process product
pass
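To make the N+1 pattern that `select_related` avoids concrete, here is a hypothetical plain-Python simulation (dicts stand in for the ORM; the data and query counter are illustrative only):

```python
# Simulated tables: without select_related, each product triggers one extra
# category lookup (the N+1 pattern); a join-style batch fetch needs one pass.
products = [{"id": i, "name": f"p{i}", "category_id": i % 3} for i in range(6)]
categories = {0: "books", 1: "toys", 2: "games"}

query_count = 0

def fetch_category(cid):
    global query_count
    query_count += 1          # one "query" per product: N queries for N rows
    return categories[cid]

names_naive = [fetch_category(p["category_id"]) for p in products]
print(query_count)            # 6 queries for 6 products

# select_related-style: one batched fetch, then join in memory
query_count = 0
cat_map = dict(categories); query_count += 1
names_joined = [cat_map[p["category_id"]] for p in products]
print(query_count)            # 1 query
```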
4. Security
# settings.py
SECURE_SSL_REDIRECT = True
SESSION_COOKIE_SECURE = True
CSRF_COOKIE_SECURE = True
SECURE_BROWSER_XSS_FILTER = True
SECURE_CONTENT_TYPE_NOSNIFF = True
X_FRAME_OPTIONS = 'DENY'
# Use Django's CSRF protection
# Always validate and sanitize user input
# Use parameterized queries (Django ORM does this by default)
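The parameterized-query point can be demonstrated with the standard-library sqlite3 module (a hypothetical one-column table; the Django ORM performs the equivalent parameter binding for you):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (name TEXT)")
conn.execute("INSERT INTO users VALUES ('alice')")

malicious = "' OR '1'='1"

# Unsafe: string interpolation lets the input rewrite the query itself
unsafe = conn.execute(f"SELECT * FROM users WHERE name = '{malicious}'").fetchall()
print(unsafe)   # [('alice',)] -- the injected OR clause matched every row

# Safe: a placeholder treats the input strictly as data, never as SQL
safe = conn.execute("SELECT * FROM users WHERE name = ?", (malicious,)).fetchall()
print(safe)     # [] -- no user is literally named "' OR '1'='1"
```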
Production Deployment
Requirements File
pip freeze > requirements.txt
requirements.txt:
Django==4.2.7
psycopg2-binary==2.9.9
python-decouple==3.8
Pillow==10.1.0
gunicorn==21.2.0
django-cors-headers==4.3.1
djangorestframework==3.14.0
Gunicorn Configuration
gunicorn.conf.py:
bind = '0.0.0.0:8000'
workers = 4
threads = 2
worker_class = 'sync'
worker_connections = 1000  # only used by async worker classes (eventlet/gevent); ignored by 'sync'
timeout = 30
keepalive = 2
accesslog = '-'
errorlog = '-'
loglevel = 'info'
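The `workers = 4` above is a placeholder; the Gunicorn docs suggest roughly `(2 × CPU cores) + 1` as a starting point, which can be computed inside `gunicorn.conf.py` itself:

```python
import multiprocessing

# Gunicorn's suggested starting point: (2 x number of cores) + 1 workers
workers = multiprocessing.cpu_count() * 2 + 1
```

Tune from there based on available memory and whether the workload is I/O-bound (favor more threads) or CPU-bound (favor more workers).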
Docker Deployment
Dockerfile:
FROM python:3.11-slim
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
RUN python manage.py collectstatic --noinput
EXPOSE 8000
CMD ["gunicorn", "--config", "gunicorn.conf.py", "myproject.wsgi:application"]
docker-compose.yml:
version: '3.8'
services:
web:
build: .
command: gunicorn myproject.wsgi:application --bind 0.0.0.0:8000
volumes:
- ./:/app
- static_volume:/app/staticfiles
- media_volume:/app/media
ports:
- "8000:8000"
env_file:
- .env
depends_on:
- db
db:
image: postgres:15-alpine
volumes:
- postgres_data:/var/lib/postgresql/data
environment:
- POSTGRES_DB=mydb
- POSTGRES_USER=postgres
- POSTGRES_PASSWORD=password
nginx:
image: nginx:alpine
volumes:
- ./nginx.conf:/etc/nginx/nginx.conf
- static_volume:/app/staticfiles
- media_volume:/app/media
ports:
- "80:80"
depends_on:
- web
volumes:
postgres_data:
static_volume:
media_volume:
Nginx Configuration
nginx.conf:
upstream django {
server web:8000;
}
server {
listen 80;
server_name example.com;
location / {
proxy_pass http://django;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
location /static/ {
alias /app/staticfiles/;
}
location /media/ {
alias /app/media/;
}
}
Resources
Official Documentation:
Learning Resources:
Community:
Tools and Packages:
Flask
Flask is a lightweight WSGI web application framework for Python. It's designed to make getting started quick and easy, with the ability to scale up to complex applications. Flask is often called a "microframework" because it doesn't require particular tools or libraries, giving developers flexibility in choosing their tools and architecture.
Table of Contents
- Introduction
- Installation and Setup
- Basic Application
- Routing
- Request and Response
- Templates with Jinja2
- Forms and Validation
- Database Integration
- Authentication
- RESTful APIs
- Blueprints
- Error Handling
- File Uploads
- Testing
- Best Practices
- Production Deployment
Introduction
Key Features:
- Minimal core with extensions for added functionality
- Built-in development server and debugger
- Integrated unit testing support
- RESTful request dispatching
- Jinja2 templating engine
- Secure cookies for client-side sessions
- WSGI 1.0 compliant
- Unicode-based
- Extensive documentation
- Active community
Use Cases:
- RESTful APIs
- Microservices
- Prototypes and MVPs
- Small to medium web applications
- Backend for single-page applications
- Data science dashboards
- Webhook handlers
- Static sites with dynamic content
Philosophy:
- Simplicity and flexibility
- Explicit over implicit
- Start small, scale when needed
- No forced dependencies
- Easy to extend
Installation and Setup
Prerequisites
# Python 3.8+ required (Flask 3.x dropped support for 3.7)
python3 --version
pip --version
Virtual Environment
# Create virtual environment
python3 -m venv venv
# Activate virtual environment
# Linux/Mac:
source venv/bin/activate
# Windows:
# venv\Scripts\activate
# Upgrade pip
pip install --upgrade pip
Install Flask
# Install Flask
pip install Flask
# Install common extensions
pip install Flask-SQLAlchemy Flask-Migrate Flask-Login Flask-WTF
pip install Flask-CORS Flask-JWT-Extended python-dotenv
Project Structure
flask-app/
├── app/
│ ├── __init__.py
│ ├── models.py
│ ├── routes.py
│ ├── forms.py
│ ├── templates/
│ │ ├── base.html
│ │ └── index.html
│ ├── static/
│ │ ├── css/
│ │ ├── js/
│ │ └── images/
│ └── blueprints/
│ ├── auth/
│ └── api/
├── tests/
│ └── test_routes.py
├── migrations/
├── config.py
├── .env
├── .flaskenv
├── requirements.txt
└── run.py
Basic Application
Minimal Flask App
from flask import Flask
app = Flask(__name__)
@app.route('/')
def hello():
return 'Hello, World!'
if __name__ == '__main__':
app.run(debug=True)
Application Factory Pattern
app/__init__.py:
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate
from flask_login import LoginManager
from config import Config
db = SQLAlchemy()
migrate = Migrate()
login_manager = LoginManager()
def create_app(config_class=Config):
app = Flask(__name__)
app.config.from_object(config_class)
# Initialize extensions
db.init_app(app)
migrate.init_app(app, db)
login_manager.init_app(app)
login_manager.login_view = 'auth.login'
# Register blueprints
from app.blueprints.auth import auth_bp
from app.blueprints.main import main_bp
from app.blueprints.api import api_bp
app.register_blueprint(auth_bp, url_prefix='/auth')
app.register_blueprint(main_bp)
app.register_blueprint(api_bp, url_prefix='/api')
return app
from app import models
config.py:
import os
from dotenv import load_dotenv
basedir = os.path.abspath(os.path.dirname(__file__))
load_dotenv(os.path.join(basedir, '.env'))
class Config:
SECRET_KEY = os.environ.get('SECRET_KEY') or 'dev-secret-key'
SQLALCHEMY_DATABASE_URI = os.environ.get('DATABASE_URL') or \
'sqlite:///' + os.path.join(basedir, 'app.db')
SQLALCHEMY_TRACK_MODIFICATIONS = False
# File upload settings
UPLOAD_FOLDER = os.path.join(basedir, 'uploads')
MAX_CONTENT_LENGTH = 16 * 1024 * 1024 # 16MB max file size
class DevelopmentConfig(Config):
DEBUG = True
class ProductionConfig(Config):
DEBUG = False
class TestingConfig(Config):
TESTING = True
SQLALCHEMY_DATABASE_URI = 'sqlite:///:memory:'
run.py:
from app import create_app, db
from app.models import User, Post
app = create_app()
@app.shell_context_processor
def make_shell_context():
return {'db': db, 'User': User, 'Post': Post}
if __name__ == '__main__':
app.run(debug=True)
Routing
Basic Routes
from flask import Flask
app = Flask(__name__)
@app.route('/')
def index():
return 'Home Page'
@app.route('/about')
def about():
return 'About Page'
# Route with variable
@app.route('/user/<username>')
def show_user(username):
return f'User: {username}'
# Route with type converter
@app.route('/post/<int:post_id>')
def show_post(post_id):
return f'Post ID: {post_id}'
# Route with multiple types
@app.route('/path/<path:subpath>')
def show_subpath(subpath):
return f'Subpath: {subpath}'
HTTP Methods
from flask import request, jsonify
@app.route('/login', methods=['GET', 'POST'])
def login():
if request.method == 'POST':
username = request.form.get('username')
password = request.form.get('password')
# Process login
return {'message': 'Login successful'}
return 'Login form'
# Separate methods
@app.get('/users')
def get_users():
return jsonify([])
@app.post('/users')
def create_user():
data = request.get_json()
return jsonify(data), 201
@app.put('/users/<int:id>')
def update_user(id):
data = request.get_json()
return jsonify(data)
@app.delete('/users/<int:id>')
def delete_user(id):
return '', 204
URL Building
from flask import url_for, redirect
@app.route('/admin')
def admin():
return 'Admin Page'
@app.route('/redirect-to-admin')
def redirect_to_admin():
return redirect(url_for('admin'))
@app.route('/user/<username>')
def profile(username):
return f'Profile: {username}'
# Generate URL
with app.test_request_context():
print(url_for('admin')) # /admin
print(url_for('profile', username='john')) # /user/john
print(url_for('static', filename='style.css')) # /static/style.css
Request and Response
Request Object
from flask import request, jsonify
@app.route('/search')
def search():
# Query parameters
query = request.args.get('q', '')
page = request.args.get('page', 1, type=int)
return f'Searching for: {query}, Page: {page}'
@app.route('/submit', methods=['POST'])
def submit():
# Form data
name = request.form.get('name')
email = request.form.get('email')
# JSON data
if request.is_json:
data = request.get_json()
name = data.get('name')
email = data.get('email')
# Files
if 'file' in request.files:
file = request.files['file']
if file.filename:
file.save(f'uploads/{file.filename}')
# Headers
user_agent = request.headers.get('User-Agent')
auth_token = request.headers.get('Authorization')
# Cookies
session_id = request.cookies.get('session_id')
return jsonify({
'name': name,
'email': email,
'user_agent': user_agent
})
Response Object
from flask import make_response, jsonify, render_template, send_file
@app.route('/json')
def json_response():
return jsonify({
'status': 'success',
'data': {'id': 1, 'name': 'John'}
})
@app.route('/custom')
def custom_response():
response = make_response('Custom response', 200)
response.headers['X-Custom-Header'] = 'Value'
response.set_cookie('user_id', '123', max_age=3600)
return response
@app.route('/download')
def download():
return send_file('path/to/file.pdf', as_attachment=True)
@app.route('/stream')
def stream():
def generate():
for i in range(10):
yield f'data: {i}\n\n'
return app.response_class(generate(), mimetype='text/event-stream')
Templates with Jinja2
Base Template
templates/base.html:
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>{% block title %}My App{% endblock %}</title>
<link rel="stylesheet" href="{{ url_for('static', filename='css/style.css') }}">
{% block extra_css %}{% endblock %}
</head>
<body>
<nav>
<a href="{{ url_for('index') }}">Home</a>
{% if current_user.is_authenticated %}
<a href="{{ url_for('profile') }}">Profile</a>
<a href="{{ url_for('logout') }}">Logout</a>
{% else %}
<a href="{{ url_for('login') }}">Login</a>
<a href="{{ url_for('register') }}">Register</a>
{% endif %}
</nav>
<main>
{% with messages = get_flashed_messages(with_categories=true) %}
{% if messages %}
{% for category, message in messages %}
<div class="alert alert-{{ category }}">{{ message }}</div>
{% endfor %}
{% endif %}
{% endwith %}
{% block content %}{% endblock %}
</main>
<footer>
<p>© 2024 My App</p>
</footer>
<script src="{{ url_for('static', filename='js/main.js') }}"></script>
{% block extra_js %}{% endblock %}
</body>
</html>
Child Template
templates/index.html:
{% extends 'base.html' %}
{% block title %}Home - {{ super() }}{% endblock %}
{% block content %}
<h1>Welcome to {{ app_name }}</h1>
{% if users %}
<ul>
{% for user in users %}
<li>
<a href="{{ url_for('show_user', username=user.username) }}">
{{ user.username }}
</a>
{% if user.is_admin %}
<span class="badge">Admin</span>
{% endif %}
</li>
{% endfor %}
</ul>
{% else %}
<p>No users found.</p>
{% endif %}
<!-- Macros -->
{% macro render_user(user) %}
<div class="user-card">
<h3>{{ user.username }}</h3>
<p>{{ user.email }}</p>
</div>
{% endmacro %}
{% for user in users %}
{{ render_user(user) }}
{% endfor %}
{% endblock %}
Template Filters and Functions
from flask import Flask
from datetime import datetime
app = Flask(__name__)
@app.template_filter('datetimeformat')
def datetimeformat(value, format='%Y-%m-%d %H:%M'):
return value.strftime(format)
@app.template_filter('currency')
def currency(value):
return f'${value:,.2f}'
@app.context_processor
def utility_processor():
def format_price(amount):
return f'${amount:,.2f}'
return dict(format_price=format_price)
# Usage in template:
# {{ order.created_at|datetimeformat }}
# {{ product.price|currency }}
# {{ format_price(100.50) }}
Forms and Validation
Flask-WTF Forms
from flask_wtf import FlaskForm
from wtforms import StringField, PasswordField, TextAreaField, SelectField, BooleanField
from wtforms.validators import DataRequired, Email, Length, EqualTo, ValidationError
from app.models import User
class RegistrationForm(FlaskForm):
username = StringField('Username',
validators=[DataRequired(), Length(min=3, max=20)])
email = StringField('Email',
validators=[DataRequired(), Email()])
password = PasswordField('Password',
validators=[DataRequired(), Length(min=8)])
confirm_password = PasswordField('Confirm Password',
validators=[DataRequired(), EqualTo('password')])
def validate_username(self, username):
user = User.query.filter_by(username=username.data).first()
if user:
raise ValidationError('Username already exists')
def validate_email(self, email):
user = User.query.filter_by(email=email.data).first()
if user:
raise ValidationError('Email already registered')
class LoginForm(FlaskForm):
email = StringField('Email', validators=[DataRequired(), Email()])
password = PasswordField('Password', validators=[DataRequired()])
remember = BooleanField('Remember Me')
class PostForm(FlaskForm):
title = StringField('Title', validators=[DataRequired(), Length(max=100)])
content = TextAreaField('Content', validators=[DataRequired()])
category = SelectField('Category', coerce=int)
def __init__(self, *args, **kwargs):
super(PostForm, self).__init__(*args, **kwargs)
from app.models import Category
self.category.choices = [(c.id, c.name) for c in Category.query.all()]
Form Handling in Views
from flask import render_template, redirect, url_for, flash
from app import db
from app.forms import RegistrationForm, LoginForm
from app.models import User
@app.route('/register', methods=['GET', 'POST'])
def register():
form = RegistrationForm()
if form.validate_on_submit():
user = User(username=form.username.data, email=form.email.data)
user.set_password(form.password.data)
db.session.add(user)
db.session.commit()
flash('Registration successful!', 'success')
return redirect(url_for('login'))
return render_template('register.html', form=form)
@app.route('/login', methods=['GET', 'POST'])
def login():
form = LoginForm()
if form.validate_on_submit():
user = User.query.filter_by(email=form.email.data).first()
if user and user.check_password(form.password.data):
login_user(user, remember=form.remember.data)
flash('Login successful!', 'success')
next_page = request.args.get('next')
return redirect(next_page) if next_page else redirect(url_for('index'))
flash('Invalid email or password', 'danger')
return render_template('login.html', form=form)
Form Template
templates/register.html:
{% extends 'base.html' %}
{% block content %}
<h2>Register</h2>
<form method="POST" novalidate>
{{ form.hidden_tag() }}
<div class="form-group">
{{ form.username.label }}
{{ form.username(class='form-control') }}
{% if form.username.errors %}
<div class="errors">
{% for error in form.username.errors %}
<span>{{ error }}</span>
{% endfor %}
</div>
{% endif %}
</div>
<div class="form-group">
{{ form.email.label }}
{{ form.email(class='form-control') }}
{% if form.email.errors %}
<div class="errors">
{% for error in form.email.errors %}
<span>{{ error }}</span>
{% endfor %}
</div>
{% endif %}
</div>
<div class="form-group">
{{ form.password.label }}
{{ form.password(class='form-control') }}
{% if form.password.errors %}
<div class="errors">
{% for error in form.password.errors %}
<span>{{ error }}</span>
{% endfor %}
</div>
{% endif %}
</div>
<div class="form-group">
{{ form.confirm_password.label }}
{{ form.confirm_password(class='form-control') }}
</div>
<button type="submit" class="btn btn-primary">Register</button>
</form>
{% endblock %}
Database Integration
SQLAlchemy Models
app/models.py:
from app import db, login_manager
from datetime import datetime
from werkzeug.security import generate_password_hash, check_password_hash
from flask_login import UserMixin
@login_manager.user_loader
def load_user(user_id):
return User.query.get(int(user_id))
class User(UserMixin, db.Model):
id = db.Column(db.Integer, primary_key=True)
username = db.Column(db.String(80), unique=True, nullable=False, index=True)
email = db.Column(db.String(120), unique=True, nullable=False, index=True)
password_hash = db.Column(db.String(128))
created_at = db.Column(db.DateTime, default=datetime.utcnow)
posts = db.relationship('Post', backref='author', lazy='dynamic')
def set_password(self, password):
self.password_hash = generate_password_hash(password)
def check_password(self, password):
return check_password_hash(self.password_hash, password)
def __repr__(self):
return f'<User {self.username}>'
class Post(db.Model):
id = db.Column(db.Integer, primary_key=True)
title = db.Column(db.String(100), nullable=False)
content = db.Column(db.Text, nullable=False)
slug = db.Column(db.String(120), unique=True, index=True)
published = db.Column(db.Boolean, default=False)
created_at = db.Column(db.DateTime, default=datetime.utcnow)
updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False)
category_id = db.Column(db.Integer, db.ForeignKey('category.id'))
def __repr__(self):
return f'<Post {self.title}>'
class Category(db.Model):
id = db.Column(db.Integer, primary_key=True)
name = db.Column(db.String(50), unique=True, nullable=False)
posts = db.relationship('Post', backref='category', lazy=True)
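Werkzeug's `generate_password_hash`/`check_password_hash` used above implement salted, iterated hashing; a rough standard-library sketch of the same idea (the PBKDF2 parameters and string format here are illustrative, not Werkzeug's exact ones):

```python
import hashlib
import hmac
import os

def hash_password(password: str, iterations: int = 260_000) -> str:
    # A fresh random salt per password defeats precomputed (rainbow-table) attacks
    salt = os.urandom(16)
    digest = hashlib.pbkdf2_hmac("sha256", password.encode(), salt, iterations)
    return f"pbkdf2:sha256:{iterations}${salt.hex()}${digest.hex()}"

def verify_password(stored: str, password: str) -> bool:
    method, salt_hex, digest_hex = stored.split("$")
    iterations = int(method.rsplit(":", 1)[1])
    digest = hashlib.pbkdf2_hmac("sha256", password.encode(),
                                 bytes.fromhex(salt_hex), iterations)
    # Constant-time comparison avoids leaking how many characters matched
    return hmac.compare_digest(digest.hex(), digest_hex)

stored = hash_password("s3cret")
print(verify_password(stored, "s3cret"))  # True
print(verify_password(stored, "wrong"))   # False
```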
Database Operations
from app import db
from app.models import User, Post
# Create
user = User(username='john', email='john@example.com')
user.set_password('password123')
db.session.add(user)
db.session.commit()
# Read
users = User.query.all()
user = User.query.filter_by(username='john').first()
user = User.query.get(1)
posts = Post.query.filter_by(published=True).order_by(Post.created_at.desc()).all()
# Update
user = User.query.get(1)
user.email = 'newemail@example.com'
db.session.commit()
# Delete
user = User.query.get(1)
db.session.delete(user)
db.session.commit()
# Complex queries
from sqlalchemy import or_, and_
posts = Post.query.filter(
or_(
Post.title.like('%python%'),
Post.content.like('%python%')
),
Post.published == True
).all()
# Pagination
page = request.args.get('page', 1, type=int)
posts = Post.query.order_by(Post.created_at.desc()).paginate(
page=page, per_page=10, error_out=False
)
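`paginate()` is essentially OFFSET/LIMIT arithmetic plus page clamping; a plain-Python sketch of what it computes (the returned field names are illustrative, not Flask-SQLAlchemy's Pagination API):

```python
import math

def paginate(items: list, page: int = 1, per_page: int = 10) -> dict:
    pages = max(1, math.ceil(len(items) / per_page))
    page = min(max(page, 1), pages)          # clamp, mirroring error_out=False
    start = (page - 1) * per_page            # same arithmetic as SQL OFFSET
    return {
        "items": items[start:start + per_page],  # SQL LIMIT
        "page": page,
        "pages": pages,
        "total": len(items),
    }

result = paginate(list(range(25)), page=3, per_page=10)
print(result["items"])   # [20, 21, 22, 23, 24]
```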
Migrations
# Initialize migrations
flask db init
# Create migration
flask db migrate -m "Add user table"
# Apply migration
flask db upgrade
# Rollback
flask db downgrade
Authentication
Flask-Login Setup
from flask_login import LoginManager, login_user, logout_user, login_required, current_user
from app import app, db
from app.models import User
login_manager = LoginManager()
login_manager.init_app(app)
login_manager.login_view = 'login'
login_manager.login_message = 'Please log in to access this page.'
@login_manager.user_loader
def load_user(user_id):
return User.query.get(int(user_id))
@app.route('/login', methods=['GET', 'POST'])
def login():
if current_user.is_authenticated:
return redirect(url_for('index'))
form = LoginForm()
if form.validate_on_submit():
user = User.query.filter_by(email=form.email.data).first()
if user and user.check_password(form.password.data):
login_user(user, remember=form.remember.data)
next_page = request.args.get('next')
return redirect(next_page) if next_page else redirect(url_for('index'))
flash('Invalid email or password', 'danger')
return render_template('login.html', form=form)
@app.route('/logout')
@login_required
def logout():
logout_user()
flash('You have been logged out', 'info')
return redirect(url_for('index'))
@app.route('/profile')
@login_required
def profile():
return render_template('profile.html', user=current_user)
JWT Authentication
from flask_jwt_extended import JWTManager, create_access_token, jwt_required, get_jwt_identity
app.config['JWT_SECRET_KEY'] = 'your-secret-key'
jwt = JWTManager(app)
@app.route('/api/auth/login', methods=['POST'])
def api_login():
data = request.get_json()
email = data.get('email')
password = data.get('password')
user = User.query.filter_by(email=email).first()
if user and user.check_password(password):
access_token = create_access_token(identity=user.id)
return jsonify(access_token=access_token), 200
return jsonify({'message': 'Invalid credentials'}), 401
@app.route('/api/protected', methods=['GET'])
@jwt_required()
def protected():
current_user_id = get_jwt_identity()
user = User.query.get(current_user_id)
return jsonify(username=user.username), 200
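Under the hood a JWT is just two base64url-encoded JSON segments plus an HMAC signature; a standard-library sketch of HS256 signing and verification (no expiry or claims handling, unlike Flask-JWT-Extended):

```python
import base64
import hashlib
import hmac
import json

def b64url(data: bytes) -> str:
    # JWT uses unpadded base64url encoding
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode()

def make_jwt(payload: dict, secret: str) -> str:
    header = b64url(json.dumps({"alg": "HS256", "typ": "JWT"}).encode())
    body = b64url(json.dumps(payload).encode())
    signing_input = f"{header}.{body}".encode()
    sig = b64url(hmac.new(secret.encode(), signing_input, hashlib.sha256).digest())
    return f"{header}.{body}.{sig}"

def verify_jwt(token: str, secret: str):
    header, body, sig = token.split(".")
    signing_input = f"{header}.{body}".encode()
    expected = b64url(hmac.new(secret.encode(), signing_input, hashlib.sha256).digest())
    if not hmac.compare_digest(sig, expected):
        return None  # signature mismatch: token was forged or tampered with
    padded = body + "=" * (-len(body) % 4)
    return json.loads(base64.urlsafe_b64decode(padded))

token = make_jwt({"sub": 1}, "server-secret")
print(verify_jwt(token, "server-secret"))  # {'sub': 1}
print(verify_jwt(token, "wrong-secret"))   # None
```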
RESTful APIs
Flask-RESTful
from flask import Flask
from flask_restful import Resource, Api, reqparse, fields, marshal_with
from app import db
from app.models import Post
app = Flask(__name__)
api = Api(app)
# Request parser
post_parser = reqparse.RequestParser()
post_parser.add_argument('title', type=str, required=True, help='Title is required')
post_parser.add_argument('content', type=str, required=True)
post_parser.add_argument('category_id', type=int)
# Resource fields for serialization
post_fields = {
'id': fields.Integer,
'title': fields.String,
'content': fields.String,
'created_at': fields.DateTime(dt_format='iso8601'),
'author': fields.Nested({
'id': fields.Integer,
'username': fields.String
})
}
class PostListAPI(Resource):
@marshal_with(post_fields)
def get(self):
posts = Post.query.all()
return posts
@marshal_with(post_fields)
def post(self):
args = post_parser.parse_args()
post = Post(
title=args['title'],
content=args['content'],
user_id=current_user.id,
category_id=args.get('category_id')
)
db.session.add(post)
db.session.commit()
return post, 201
class PostAPI(Resource):
@marshal_with(post_fields)
def get(self, post_id):
post = Post.query.get_or_404(post_id)
return post
@marshal_with(post_fields)
def put(self, post_id):
post = Post.query.get_or_404(post_id)
args = post_parser.parse_args()
post.title = args['title']
post.content = args['content']
db.session.commit()
return post
def delete(self, post_id):
post = Post.query.get_or_404(post_id)
db.session.delete(post)
db.session.commit()
return '', 204
api.add_resource(PostListAPI, '/api/posts')
api.add_resource(PostAPI, '/api/posts/<int:post_id>')
Blueprints
Creating Blueprints
app/blueprints/auth/__init__.py:
from flask import Blueprint
auth_bp = Blueprint('auth', __name__)
from app.blueprints.auth import routes
app/blueprints/auth/routes.py:
from flask import render_template, redirect, url_for, flash, request
from flask_login import login_user, logout_user, login_required
from app.blueprints.auth import auth_bp
from app import db
from app.models import User
from app.forms import LoginForm, RegistrationForm
@auth_bp.route('/login', methods=['GET', 'POST'])
def login():
form = LoginForm()
if form.validate_on_submit():
user = User.query.filter_by(email=form.email.data).first()
if user and user.check_password(form.password.data):
login_user(user, remember=form.remember.data)
return redirect(url_for('main.index'))
flash('Invalid credentials', 'danger')
return render_template('auth/login.html', form=form)
@auth_bp.route('/logout')
@login_required
def logout():
logout_user()
return redirect(url_for('main.index'))
@auth_bp.route('/register', methods=['GET', 'POST'])
def register():
form = RegistrationForm()
if form.validate_on_submit():
user = User(username=form.username.data, email=form.email.data)
user.set_password(form.password.data)
db.session.add(user)
db.session.commit()
flash('Registration successful!', 'success')
return redirect(url_for('auth.login'))
return render_template('auth/register.html', form=form)
Registering Blueprints
app/__init__.py:
def create_app():
app = Flask(__name__)
# Register blueprints
from app.blueprints.auth import auth_bp
from app.blueprints.main import main_bp
from app.blueprints.api import api_bp
app.register_blueprint(auth_bp, url_prefix='/auth')
app.register_blueprint(main_bp)
app.register_blueprint(api_bp, url_prefix='/api')
return app
Error Handling
from flask import render_template, jsonify
@app.errorhandler(404)
def not_found_error(error):
if request.path.startswith('/api/'):
return jsonify({'error': 'Not found'}), 404
return render_template('errors/404.html'), 404
@app.errorhandler(500)
def internal_error(error):
db.session.rollback()
if request.path.startswith('/api/'):
return jsonify({'error': 'Internal server error'}), 500
return render_template('errors/500.html'), 500
@app.errorhandler(403)
def forbidden_error(error):
return jsonify({'error': 'Forbidden'}), 403
# Custom exception
class ValidationError(Exception):
pass
@app.errorhandler(ValidationError)
def handle_validation_error(error):
return jsonify({'error': str(error)}), 400
File Uploads
import os
from werkzeug.utils import secure_filename
from flask import request, flash, redirect, url_for
ALLOWED_EXTENSIONS = {'txt', 'pdf', 'png', 'jpg', 'jpeg', 'gif'}
def allowed_file(filename):
return '.' in filename and \
filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
@app.route('/upload', methods=['GET', 'POST'])
@login_required
def upload_file():
if request.method == 'POST':
if 'file' not in request.files:
flash('No file part', 'danger')
return redirect(request.url)
file = request.files['file']
if file.filename == '':
flash('No selected file', 'danger')
return redirect(request.url)
if file and allowed_file(file.filename):
filename = secure_filename(file.filename)
filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
file.save(filepath)
flash('File uploaded successfully', 'success')
return redirect(url_for('index'))
return render_template('upload.html')
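`secure_filename` does the heavy lifting above; a rough standard-library approximation of that kind of sanitization (Werkzeug's actual rules differ in the details):

```python
import os
import re

def sanitize_filename(filename: str) -> str:
    # Keep only the base name, stripping directory components like "../../etc/"
    filename = os.path.basename(filename.replace("\\", "/"))
    # Replace anything outside a conservative character set with underscores
    filename = re.sub(r"[^A-Za-z0-9._-]", "_", filename)
    # Drop leading dots so the result cannot be a hidden or relative path
    return filename.lstrip(".")

print(sanitize_filename("../../etc/passwd"))    # passwd
print(sanitize_filename("my report (v2).pdf"))  # my_report__v2_.pdf
```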
Testing
import unittest
from app import create_app, db
from app.models import User, Post
from config import TestingConfig
class UserModelTestCase(unittest.TestCase):
def setUp(self):
self.app = create_app(TestingConfig)
self.app_context = self.app.app_context()
self.app_context.push()
db.create_all()
def tearDown(self):
db.session.remove()
db.drop_all()
self.app_context.pop()
def test_password_hashing(self):
user = User(username='john', email='john@example.com')
user.set_password('password')
self.assertFalse(user.check_password('wrong'))
self.assertTrue(user.check_password('password'))
class RoutesTestCase(unittest.TestCase):
def setUp(self):
self.app = create_app(TestingConfig)
self.client = self.app.test_client()
self.app_context = self.app.app_context()
self.app_context.push()
db.create_all()
def tearDown(self):
db.session.remove()
db.drop_all()
self.app_context.pop()
def test_index_page(self):
response = self.client.get('/')
self.assertEqual(response.status_code, 200)
def test_login(self):
# Create user
user = User(username='test', email='test@example.com')
user.set_password('password')
db.session.add(user)
db.session.commit()
# Test login
response = self.client.post('/auth/login', data={
'email': 'test@example.com',
'password': 'password'
}, follow_redirects=True)
self.assertEqual(response.status_code, 200)
if __name__ == '__main__':
unittest.main()
Best Practices
1. Application Factory
def create_app(config_class=Config):
app = Flask(__name__)
app.config.from_object(config_class)
db.init_app(app)
migrate.init_app(app, db)
return app
2. Configuration Management
# Use environment variables
from dotenv import load_dotenv
load_dotenv()
SECRET_KEY = os.environ.get('SECRET_KEY')
DATABASE_URL = os.environ.get('DATABASE_URL')
3. Error Handling
# Always handle exceptions
try:
# Database operation
db.session.commit()
except Exception as e:
db.session.rollback()
app.logger.error(f'Error: {str(e)}')
flash('An error occurred', 'danger')
4. Security
# CSRF protection
from flask_wtf.csrf import CSRFProtect
csrf = CSRFProtect(app)
# Security headers
from flask_talisman import Talisman
Talisman(app, content_security_policy=None)
# Rate limiting (Flask-Limiter 3.x constructor: key_func first, app as keyword)
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
limiter = Limiter(get_remote_address, app=app)
@app.route('/api/data')
@limiter.limit("5 per minute")
def api_data():
return jsonify({'data': []})
Production Deployment
Requirements
requirements.txt:
Flask==3.0.0
Flask-SQLAlchemy==3.1.1
Flask-Migrate==4.0.5
Flask-Login==0.6.3
Flask-WTF==1.2.1
Flask-CORS==4.0.0
Flask-JWT-Extended==4.5.3
python-dotenv==1.0.0
gunicorn==21.2.0
psycopg2-binary==2.9.9
Gunicorn
# Install
pip install gunicorn
# Run
gunicorn -w 4 -b 0.0.0.0:8000 "app:create_app()"
# With config file
gunicorn -c gunicorn.conf.py "app:create_app()"
gunicorn.conf.py:
bind = '0.0.0.0:8000'
workers = 4
threads = 2
timeout = 30
accesslog = '-'
errorlog = '-'
loglevel = 'info'
Docker
Dockerfile:
FROM python:3.11-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
CMD ["gunicorn", "-c", "gunicorn.conf.py", "app:create_app()"]
docker-compose.yml:
version: '3.8'
services:
web:
build: .
ports:
- "8000:8000"
environment:
- FLASK_ENV=production
- DATABASE_URL=postgresql://user:pass@db:5432/mydb
depends_on:
- db
db:
image: postgres:15-alpine
environment:
POSTGRES_PASSWORD: password
POSTGRES_DB: mydb
volumes:
- postgres_data:/var/lib/postgresql/data
volumes:
postgres_data:
Resources
Official Documentation:
Extensions:
Community:
Books:
- Flask Web Development by Miguel Grinberg
- Flask Framework Cookbook
FastAPI
FastAPI is a modern, fast (high-performance) web framework for building APIs with Python 3.7+ based on standard Python type hints. It's designed to be easy to use and learn while providing production-ready code with automatic API documentation, data validation, and serialization.
Table of Contents
- Introduction
- Installation and Setup
- Basic Application
- Path Operations
- Request and Response Models
- Dependency Injection
- Database Integration
- Authentication and Security
- Background Tasks
- WebSockets
- File Operations
- Testing
- Best Practices
- Production Deployment
Introduction
Key Features:
- Fast performance (on par with NodeJS and Go)
- Automatic interactive API documentation (Swagger UI and ReDoc)
- Based on standard Python type hints
- Data validation using Pydantic
- Asynchronous support with async/await
- Dependency injection system
- OAuth2 and JWT authentication built-in
- WebSocket support
- GraphQL support
- Minimal boilerplate code
- Production-ready with automatic error responses
Use Cases:
- RESTful APIs
- Microservices
- Real-time applications
- Machine learning model serving
- Data science APIs
- Backend for mobile/web applications
- API gateways
- WebSocket servers
Why FastAPI?
- One of the fastest Python frameworks available, according to independent benchmarks
- Fewer bugs: the FastAPI docs estimate roughly 40% fewer developer-induced errors thanks to type hints
- Easy to learn, fast to code
- Editor support with autocomplete
- Reduces code duplication
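The "type hints drive everything" idea can be illustrated without FastAPI itself: below is a hypothetical helper that inspects a function's signature and coerces string inputs, a toy version of how FastAPI converts path and query strings into typed arguments (`call_with_coercion` and `read_item` are made up for this sketch).

```python
import inspect

def call_with_coercion(func, raw_params: dict):
    """Coerce incoming strings to the parameter types the function declares."""
    sig = inspect.signature(func)
    kwargs = {}
    for name, param in sig.parameters.items():
        value = raw_params.get(name, param.default)
        if param.annotation is not inspect.Parameter.empty and isinstance(value, str):
            value = param.annotation(value)  # e.g. int("42") -> 42
        kwargs[name] = value
    return func(**kwargs)

def read_item(item_id: int, q: str = ""):
    return {"item_id": item_id, "q": q}

print(call_with_coercion(read_item, {"item_id": "42"}))
# {'item_id': 42, 'q': ''}
```

FastAPI layers validation, error responses, and OpenAPI documentation on top of this same signature introspection.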
Installation and Setup
Prerequisites
# Python 3.7+ required
python3 --version
pip --version
Install FastAPI
# Create virtual environment
python3 -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
# Install FastAPI and Uvicorn
pip install fastapi
pip install "uvicorn[standard]"
# Install additional dependencies
pip install python-multipart # For file uploads
pip install python-jose[cryptography] # For JWT
pip install passlib[bcrypt] # For password hashing
pip install sqlalchemy # For database
pip install alembic # For migrations
pip install pydantic[email] # For email validation
Project Structure
fastapi-app/
├── app/
│ ├── __init__.py
│ ├── main.py
│ ├── config.py
│ ├── database.py
│ ├── dependencies.py
│ ├── models/
│ │ ├── __init__.py
│ │ └── user.py
│ ├── schemas/
│ │ ├── __init__.py
│ │ └── user.py
│ ├── routers/
│ │ ├── __init__.py
│ │ ├── users.py
│ │ └── auth.py
│ ├── services/
│ │ └── auth.py
│ └── utils/
│ └── security.py
├── tests/
│ └── test_main.py
├── alembic/
├── .env
├── requirements.txt
└── README.md
Basic Application
Minimal App
from typing import Optional

from fastapi import FastAPI

app = FastAPI()

@app.get("/")
def read_root():
    return {"message": "Hello World"}

@app.get("/items/{item_id}")
def read_item(item_id: int, q: Optional[str] = None):
    return {"item_id": item_id, "q": q}
# Run with: uvicorn main:app --reload
Full Application Setup
app/main.py:
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.routers import users, auth, items
from app.database import engine
from app import models

models.Base.metadata.create_all(bind=engine)

app = FastAPI(
    title="My API",
    description="A production-ready FastAPI application",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# CORS middleware
# Note: browsers reject allow_origins=["*"] combined with credentials;
# list explicit origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Include routers
app.include_router(auth.router, prefix="/auth", tags=["auth"])
app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(items.router, prefix="/items", tags=["items"])

@app.get("/")
async def root():
    return {"message": "Welcome to FastAPI"}

@app.get("/health")
async def health_check():
    return {"status": "healthy"}
app/config.py:
from pydantic_settings import BaseSettings

class Settings(BaseSettings):
    app_name: str = "FastAPI App"
    database_url: str = "sqlite:///./test.db"
    secret_key: str = "your-secret-key-here"
    algorithm: str = "HS256"
    access_token_expire_minutes: int = 30

    class Config:
        env_file = ".env"

settings = Settings()
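`Settings` reads overrides from the `.env` file named above; pydantic-settings matches environment variable names to field names case-insensitively. A matching example file (all values are placeholders):

```shell
DATABASE_URL=sqlite:///./app.db
SECRET_KEY=replace-with-a-long-random-string
ALGORITHM=HS256
ACCESS_TOKEN_EXPIRE_MINUTES=30
```

Keep `.env` out of version control; the `SECRET_KEY` signs every access token.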
Path Operations
HTTP Methods
from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel
from typing import List, Optional

app = FastAPI()

class Item(BaseModel):
    name: str
    description: Optional[str] = None
    price: float
    tax: Optional[float] = None

# GET
@app.get("/items")
async def get_items():
    return [{"id": 1, "name": "Item 1"}]

# GET with path parameter
@app.get("/items/{item_id}")
async def get_item(item_id: int):
    return {"item_id": item_id}

# POST
@app.post("/items", status_code=status.HTTP_201_CREATED)
async def create_item(item: Item):
    return item

# PUT
@app.put("/items/{item_id}")
async def update_item(item_id: int, item: Item):
    return {"item_id": item_id, **item.model_dump()}  # model_dump() is the Pydantic v2 API

# PATCH
@app.patch("/items/{item_id}")
async def partial_update_item(item_id: int, item: dict):
    return {"item_id": item_id, "updated_fields": item}

# DELETE
@app.delete("/items/{item_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_item(item_id: int):
    return None
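The PATCH handler above accepts an untyped `dict`, so any keys pass through. A more type-safe pattern (a sketch; the `ItemUpdate` model and stored record are hypothetical) validates the partial body and applies only the fields the client actually sent:

```python
from typing import Optional
from pydantic import BaseModel

class ItemUpdate(BaseModel):
    name: Optional[str] = None
    price: Optional[float] = None

stored = {"name": "Pen", "price": 2.5}

# Simulate a PATCH body that only changes the price
patch = ItemUpdate(price=3.0)

# exclude_unset drops fields the client never sent, so name is untouched
stored.update(patch.model_dump(exclude_unset=True))
print(stored)   # {'name': 'Pen', 'price': 3.0}
```

The same `exclude_unset=True` trick appears again in the CRUD service layer below for `update_user`.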
Query Parameters
from typing import Optional, List
from enum import Enum
from fastapi import Query

class SortOrder(str, Enum):
    asc = "asc"
    desc = "desc"

@app.get("/items")
async def list_items(
    skip: int = 0,
    limit: int = 10,
    q: Optional[str] = None,
    sort: SortOrder = SortOrder.asc,
    tags: List[str] = Query(default=[])  # list query params need explicit Query, else FastAPI expects a body
):
return {
"skip": skip,
"limit": limit,
"q": q,
"sort": sort,
"tags": tags
}
# Required query parameter
@app.get("/search")
async def search(q: str): # Required
return {"q": q}
Path Parameters
from uuid import UUID
from datetime import date
@app.get("/users/{user_id}")
async def get_user(user_id: int):
return {"user_id": user_id}
@app.get("/orders/{order_id}")
async def get_order(order_id: UUID):
return {"order_id": str(order_id)}
@app.get("/posts/{year}/{month}/{day}")
async def get_posts_by_date(year: int, month: int, day: int):
post_date = date(year, month, day)
return {"date": post_date}
# Path with validation
from fastapi import Path
@app.get("/items/{item_id}")
async def get_item(
item_id: int = Path(..., title="The ID of the item", ge=1)
):
return {"item_id": item_id}
Request and Response Models
Pydantic Models
from pydantic import BaseModel, Field, EmailStr, field_validator
from typing import Optional, List
from datetime import datetime

class UserBase(BaseModel):
    email: EmailStr
    username: str = Field(..., min_length=3, max_length=50)
    full_name: Optional[str] = None

class UserCreate(UserBase):
    password: str = Field(..., min_length=8)

    @field_validator('password')  # Pydantic v2 replacement for the deprecated @validator
    @classmethod
    def password_strength(cls, v):
        if not any(char.isdigit() for char in v):
            raise ValueError('Password must contain at least one digit')
        if not any(char.isupper() for char in v):
            raise ValueError('Password must contain at least one uppercase letter')
        return v

class UserUpdate(BaseModel):
    email: Optional[EmailStr] = None
    full_name: Optional[str] = None

class User(UserBase):
    id: int
    is_active: bool = True
    created_at: datetime

    class Config:
        from_attributes = True

class UserInDB(User):
    hashed_password: str

# Product models
class Product(BaseModel):
    name: str
    description: Optional[str] = None
    price: float = Field(..., gt=0, description="Price must be greater than zero")
    tax: Optional[float] = 0
    tags: List[str] = []

class ProductResponse(Product):
    id: int
    created_at: datetime

    class Config:
        from_attributes = True
Request Body
from fastapi import Body
@app.post("/users")
async def create_user(user: UserCreate):
return user
# Multiple body parameters
@app.post("/items")
async def create_item(
item: Item,
user: User,
importance: int = Body(...)
):
return {"item": item, "user": user, "importance": importance}
# Embed single body parameter
@app.post("/items/{item_id}")
async def update_item(
item_id: int,
item: Item = Body(..., embed=True)
):
return {"item_id": item_id, "item": item}
Response Models
from fastapi import Response, status
@app.post("/users", response_model=User, status_code=status.HTTP_201_CREATED)
async def create_user(user: UserCreate):
# Don't return password in response
return user
# Multiple response models
from fastapi.responses import JSONResponse
@app.get("/items/{item_id}")
async def get_item(item_id: int):
if item_id == 0:
return JSONResponse(
status_code=404,
content={"message": "Item not found"}
)
return {"item_id": item_id}
# Response with Union types
from typing import Union

@app.get("/items/{item_id}", response_model=Union[Product, dict])
async def get_item(item_id: int):
    if item_id > 0:
        return Product(name="Sample", price=9.99)  # placeholder; normally fetched from storage
    return {"message": "No item found"}
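The `response_model=User` pattern above filters out fields like the password because serializing through a narrower model drops anything that model does not declare. A standalone sketch of that behavior (models simplified and hypothetical):

```python
from pydantic import BaseModel

class PublicUser(BaseModel):
    id: int
    email: str

class DBUser(PublicUser):
    hashed_password: str

db_user = DBUser(id=1, email="a@example.com", hashed_password="secret-hash")

# Re-validating through the narrower model ignores the extra field
# (Pydantic v2 ignores unknown keys by default), which mirrors what
# response_model does to the object a path operation returns.
public = PublicUser(**db_user.model_dump())
print(public.model_dump())   # {'id': 1, 'email': 'a@example.com'}
```

This is why returning a full database object from an endpoint with a restricted `response_model` is safe: the client only ever sees the declared fields.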
Dependency Injection
Basic Dependencies
from fastapi import Depends, HTTPException, status
from typing import Optional
# Simple dependency
async def common_parameters(q: Optional[str] = None, skip: int = 0, limit: int = 100):
return {"q": q, "skip": skip, "limit": limit}
@app.get("/items")
async def read_items(commons: dict = Depends(common_parameters)):
return commons
# Class-based dependency
class CommonQueryParams:
def __init__(self, q: Optional[str] = None, skip: int = 0, limit: int = 100):
self.q = q
self.skip = skip
self.limit = limit
@app.get("/users")
async def read_users(commons: CommonQueryParams = Depends()):
return commons
Database Dependency
app/database.py:
from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker, Session  # declarative_base moved to sqlalchemy.orm in 2.0
from app.config import settings
engine = create_engine(
settings.database_url,
connect_args={"check_same_thread": False} if "sqlite" in settings.database_url else {}
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
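FastAPI runs the code after `yield` once the response is finished, which is why `db.close()` always executes even when the route raises. A plain-Python sketch of that lifecycle (the resource dict is invented for illustration):

```python
def get_resource():
    resource = {"open": True}
    try:
        yield resource            # handed to the path operation, like the db session
    finally:
        resource["open"] = False  # cleanup runs even if the route raised

gen = get_resource()
res = next(gen)       # FastAPI: resolve the dependency before the route runs
assert res["open"]
gen.close()           # FastAPI: after the response, close the generator
print(res["open"])    # False
```

`gen.close()` throws `GeneratorExit` into the generator, so the `finally` block runs; FastAPI does the equivalent for every dependency declared with `yield`.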
Current User Dependency
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from sqlalchemy.orm import Session
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token")
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: Session = Depends(get_db)
):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, settings.secret_key, algorithms=[settings.algorithm])
        sub = payload.get("sub")
        if sub is None:
            raise credentials_exception
        user_id = int(sub)  # "sub" is stored as a string when the token is created
    except (JWTError, ValueError):
        raise credentials_exception
    user = db.query(User).filter(User.id == user_id).first()
    if user is None:
        raise credentials_exception
    return user
# Use dependency
@app.get("/users/me", response_model=User)
async def read_users_me(current_user: User = Depends(get_current_user)):
return current_user
Database Integration
SQLAlchemy Models
app/models/user.py:
from sqlalchemy import Boolean, Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
from app.database import Base
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True, index=True)
email = Column(String, unique=True, index=True, nullable=False)
username = Column(String, unique=True, index=True, nullable=False)
full_name = Column(String)
hashed_password = Column(String, nullable=False)
is_active = Column(Boolean, default=True)
is_superuser = Column(Boolean, default=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
items = relationship("Item", back_populates="owner")
class Item(Base):
__tablename__ = "items"
id = Column(Integer, primary_key=True, index=True)
title = Column(String, index=True)
description = Column(String)
owner_id = Column(Integer, ForeignKey("users.id"))
created_at = Column(DateTime(timezone=True), server_default=func.now())
owner = relationship("User", back_populates="items")
CRUD Operations
app/services/user.py:
from sqlalchemy.orm import Session
from app.models.user import User
from app.schemas.user import UserCreate, UserUpdate
from app.utils.security import get_password_hash
def get_user(db: Session, user_id: int):
return db.query(User).filter(User.id == user_id).first()
def get_user_by_email(db: Session, email: str):
return db.query(User).filter(User.email == email).first()
def get_users(db: Session, skip: int = 0, limit: int = 100):
return db.query(User).offset(skip).limit(limit).all()
def create_user(db: Session, user: UserCreate):
hashed_password = get_password_hash(user.password)
db_user = User(
email=user.email,
username=user.username,
full_name=user.full_name,
hashed_password=hashed_password
)
db.add(db_user)
db.commit()
db.refresh(db_user)
return db_user
def update_user(db: Session, user_id: int, user: UserUpdate):
db_user = get_user(db, user_id)
if db_user:
update_data = user.model_dump(exclude_unset=True)  # Pydantic v2 replacement for .dict()
for key, value in update_data.items():
setattr(db_user, key, value)
db.commit()
db.refresh(db_user)
return db_user
def delete_user(db: Session, user_id: int):
db_user = get_user(db, user_id)
if db_user:
db.delete(db_user)
db.commit()
return db_user
Router with Database
app/routers/users.py:
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from typing import List
from app.database import get_db
from app.schemas.user import User, UserCreate, UserUpdate
from app.services import user as user_service
from app.dependencies import get_current_active_user
router = APIRouter()
@router.get("/", response_model=List[User])
def read_users(
skip: int = 0,
limit: int = 100,
db: Session = Depends(get_db)
):
users = user_service.get_users(db, skip=skip, limit=limit)
return users
@router.get("/{user_id}", response_model=User)
def read_user(user_id: int, db: Session = Depends(get_db)):
db_user = user_service.get_user(db, user_id=user_id)
if db_user is None:
raise HTTPException(status_code=404, detail="User not found")
return db_user
@router.post("/", response_model=User, status_code=status.HTTP_201_CREATED)
def create_user(user: UserCreate, db: Session = Depends(get_db)):
db_user = user_service.get_user_by_email(db, email=user.email)
if db_user:
raise HTTPException(status_code=400, detail="Email already registered")
return user_service.create_user(db=db, user=user)
@router.put("/{user_id}", response_model=User)
def update_user(
user_id: int,
user: UserUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user)
):
if current_user.id != user_id and not current_user.is_superuser:
raise HTTPException(status_code=403, detail="Not authorized")
db_user = user_service.update_user(db, user_id=user_id, user=user)
if db_user is None:
raise HTTPException(status_code=404, detail="User not found")
return db_user
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_user(
user_id: int,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user)
):
if current_user.id != user_id and not current_user.is_superuser:
raise HTTPException(status_code=403, detail="Not authorized")
db_user = user_service.delete_user(db, user_id=user_id)
if db_user is None:
raise HTTPException(status_code=404, detail="User not found")
Authentication and Security
Password Hashing
app/utils/security.py:
from datetime import datetime, timedelta
from typing import Optional

from jose import jwt
from passlib.context import CryptContext

from app.config import settings

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

def verify_password(plain_password: str, hashed_password: str) -> bool:
    return pwd_context.verify(plain_password, hashed_password)

def get_password_hash(password: str) -> str:
    return pwd_context.hash(password)

def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    to_encode = data.copy()
    expire = datetime.utcnow() + (expires_delta or timedelta(minutes=15))
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm)
JWT Authentication
app/routers/auth.py:
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.orm import Session
from datetime import timedelta

from app.database import get_db
from app.schemas.auth import Token
from app.schemas.user import User, UserCreate  # used by /register below
from app.services import user as user_service
from app.utils.security import verify_password, create_access_token
from app.config import settings

router = APIRouter()

@router.post("/token", response_model=Token)
async def login(
    form_data: OAuth2PasswordRequestForm = Depends(),
    db: Session = Depends(get_db)
):
    user = user_service.get_user_by_email(db, email=form_data.username)
    if not user or not verify_password(form_data.password, user.hashed_password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    if not user.is_active:
        raise HTTPException(status_code=400, detail="Inactive user")
    access_token_expires = timedelta(minutes=settings.access_token_expire_minutes)
    access_token = create_access_token(
        data={"sub": str(user.id)},
        expires_delta=access_token_expires
    )
    return {"access_token": access_token, "token_type": "bearer"}

@router.post("/register", response_model=User, status_code=status.HTTP_201_CREATED)
async def register(user: UserCreate, db: Session = Depends(get_db)):
    db_user = user_service.get_user_by_email(db, email=user.email)
    if db_user:
        raise HTTPException(status_code=400, detail="Email already registered")
    return user_service.create_user(db=db, user=user)
API Key Authentication
from fastapi import Security, HTTPException, status
from fastapi.security.api_key import APIKeyHeader
API_KEY = "your-api-key-here"
api_key_header = APIKeyHeader(name="X-API-Key")
async def verify_api_key(api_key: str = Security(api_key_header)):
if api_key != API_KEY:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API Key"
)
return api_key
@app.get("/secure-data")
async def get_secure_data(api_key: str = Depends(verify_api_key)):
return {"data": "sensitive information"}
Background Tasks
from fastapi import BackgroundTasks
import smtplib
from email.mime.text import MIMEText
def send_email(email: str, subject: str, body: str):
# Email sending logic
print(f"Sending email to {email}: {subject}")
def write_log(message: str):
with open("log.txt", mode="a") as log:
log.write(message + "\n")
@app.post("/send-notification/{email}")
async def send_notification(
email: str,
background_tasks: BackgroundTasks
):
background_tasks.add_task(send_email, email, "Welcome!", "Thanks for signing up")
background_tasks.add_task(write_log, f"Notification sent to {email}")
return {"message": "Notification sent in the background"}
# Multiple background tasks
@app.post("/users")
async def create_user(
user: UserCreate,
background_tasks: BackgroundTasks,
db: Session = Depends(get_db)
):
db_user = user_service.create_user(db, user)
background_tasks.add_task(send_email, user.email, "Welcome", "Thanks for joining!")
background_tasks.add_task(write_log, f"User created: {user.email}")
return db_user
WebSockets
from fastapi import WebSocket, WebSocketDisconnect
from typing import List
class ConnectionManager:
def __init__(self):
self.active_connections: List[WebSocket] = []
async def connect(self, websocket: WebSocket):
await websocket.accept()
self.active_connections.append(websocket)
def disconnect(self, websocket: WebSocket):
self.active_connections.remove(websocket)
async def send_personal_message(self, message: str, websocket: WebSocket):
await websocket.send_text(message)
async def broadcast(self, message: str):
for connection in self.active_connections:
await connection.send_text(message)
manager = ConnectionManager()
@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: int):
await manager.connect(websocket)
try:
while True:
data = await websocket.receive_text()
await manager.send_personal_message(f"You wrote: {data}", websocket)
await manager.broadcast(f"Client #{client_id} says: {data}")
except WebSocketDisconnect:
manager.disconnect(websocket)
await manager.broadcast(f"Client #{client_id} left the chat")
File Operations
File Upload
from fastapi import File, UploadFile
from typing import List
import shutil
from pathlib import Path
UPLOAD_DIR = Path("uploads")
UPLOAD_DIR.mkdir(exist_ok=True)
@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
file_path = UPLOAD_DIR / file.filename
with file_path.open("wb") as buffer:
shutil.copyfileobj(file.file, buffer)
return {
"filename": file.filename,
"content_type": file.content_type,
"size": file_path.stat().st_size
}
@app.post("/upload-multiple")
async def upload_multiple_files(files: List[UploadFile] = File(...)):
file_info = []
for file in files:
file_path = UPLOAD_DIR / file.filename
with file_path.open("wb") as buffer:
shutil.copyfileobj(file.file, buffer)
file_info.append({
"filename": file.filename,
"size": file_path.stat().st_size
})
return {"files": file_info}
File Download
from fastapi.responses import FileResponse, StreamingResponse
import io
@app.get("/download/{filename}")
async def download_file(filename: str):
file_path = UPLOAD_DIR / filename
if not file_path.exists():
raise HTTPException(status_code=404, detail="File not found")
return FileResponse(file_path, filename=filename)
@app.get("/stream")
async def stream_file():
def iterfile():
with open("large_file.txt", mode="rb") as file:
yield from file
return StreamingResponse(iterfile(), media_type="text/plain")
Testing
from fastapi.testclient import TestClient
from app.main import app
from app.database import Base, engine, get_db
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
# Test database
SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"
test_engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=test_engine)
def override_get_db():
try:
db = TestingSessionLocal()
yield db
finally:
db.close()
app.dependency_overrides[get_db] = override_get_db
client = TestClient(app)
# Test functions
def test_read_root():
response = client.get("/")
assert response.status_code == 200
assert response.json() == {"message": "Welcome to FastAPI"}
def test_create_user():
Base.metadata.create_all(bind=test_engine)
response = client.post(
"/users",
json={
"email": "test@example.com",
"username": "testuser",
"password": "TestPass123"
}
)
assert response.status_code == 201
data = response.json()
assert data["email"] == "test@example.com"
assert "id" in data
Base.metadata.drop_all(bind=test_engine)
def test_login():
Base.metadata.create_all(bind=test_engine)
# Create user
client.post(
"/auth/register",
json={
"email": "test@example.com",
"username": "testuser",
"password": "TestPass123"
}
)
# Login
response = client.post(
"/auth/token",
data={
"username": "test@example.com",
"password": "TestPass123"
}
)
assert response.status_code == 200
assert "access_token" in response.json()
Base.metadata.drop_all(bind=test_engine)
def test_authenticated_route():
    Base.metadata.create_all(bind=test_engine)
    # Register a user first; each test manages its own tables
    client.post(
        "/auth/register",
        json={"email": "test@example.com", "username": "testuser", "password": "TestPass123"}
    )
    # Get token
    response = client.post("/auth/token", data={"username": "test@example.com", "password": "TestPass123"})
    token = response.json()["access_token"]
    # Access protected route
    response = client.get(
        "/users/me",
        headers={"Authorization": f"Bearer {token}"}
    )
    assert response.status_code == 200
    Base.metadata.drop_all(bind=test_engine)
Best Practices
1. Project Structure
# Use modular structure with routers
app/
├── routers/
│ ├── users.py
│ ├── items.py
│ └── auth.py
├── models/
├── schemas/
├── services/
└── utils/
2. Environment Variables
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
database_url: str
secret_key: str
class Config:
env_file = ".env"
3. Error Handling
from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse
@app.exception_handler(ValueError)
async def value_error_handler(request: Request, exc: ValueError):
return JSONResponse(
status_code=400,
content={"message": str(exc)}
)
4. Async Operations
import asyncio
import httpx
@app.get("/external-api")
async def call_external_api():
async with httpx.AsyncClient() as client:
response = await client.get("https://api.example.com/data")
return response.json()
5. Middleware
import time
from fastapi import Request
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
start_time = time.time()
response = await call_next(request)
process_time = time.time() - start_time
response.headers["X-Process-Time"] = str(process_time)
return response
Production Deployment
Requirements
requirements.txt:
fastapi==0.104.1
uvicorn[standard]==0.24.0
sqlalchemy==2.0.23
alembic==1.12.1
pydantic[email]==2.5.0
python-jose[cryptography]==3.3.0
passlib[bcrypt]==1.7.4
python-multipart==0.0.6
python-dotenv==1.0.0
Uvicorn with Workers
# Development
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
# Production
uvicorn app.main:app --host 0.0.0.0 --port 8000 --workers 4
# With Gunicorn
gunicorn app.main:app --workers 4 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:8000
Docker
Dockerfile:
FROM python:3.11-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY ./app ./app
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
docker-compose.yml:
version: '3.8'
services:
api:
build: .
ports:
- "8000:8000"
environment:
- DATABASE_URL=postgresql://user:password@db:5432/mydb
- SECRET_KEY=${SECRET_KEY}
depends_on:
- db
db:
image: postgres:15-alpine
environment:
POSTGRES_PASSWORD: password
POSTGRES_DB: mydb
volumes:
- postgres_data:/var/lib/postgresql/data
volumes:
postgres_data:
Web APIs
Overview
Web APIs are interfaces provided by browsers that allow JavaScript to interact with browser features, device hardware, and web platform capabilities. These APIs enable rich, interactive web applications without requiring plugins or native code.
Storage APIs
Web Storage (localStorage & sessionStorage)
Simple key-value storage for strings:
// LocalStorage (persists across sessions)
// Store data
localStorage.setItem('username', 'john_doe');
localStorage.setItem('theme', 'dark');
// Retrieve data
const username = localStorage.getItem('username');
console.log(username); // 'john_doe'
// Store objects (must serialize)
const user = { name: 'John', age: 30 };
localStorage.setItem('user', JSON.stringify(user));
// Retrieve objects (must parse)
const storedUser = JSON.parse(localStorage.getItem('user'));
// Remove item
localStorage.removeItem('username');
// Clear all
localStorage.clear();
// Get all keys
for (let i = 0; i < localStorage.length; i++) {
const key = localStorage.key(i);
console.log(key, localStorage.getItem(key));
}
// SessionStorage (cleared when tab closes)
sessionStorage.setItem('sessionId', '12345');
sessionStorage.getItem('sessionId');
// Storage event (listen for changes in other tabs)
window.addEventListener('storage', (e) => {
console.log('Storage changed:');
console.log('Key:', e.key);
console.log('Old value:', e.oldValue);
console.log('New value:', e.newValue);
console.log('URL:', e.url);
});
// Limitations:
// - 5-10 MB limit (varies by browser)
// - Strings only (must serialize objects)
// - Synchronous (blocks main thread)
// - No expiration mechanism
IndexedDB
Powerful client-side database for structured data:
// Open database
const request = indexedDB.open('MyDatabase', 1);
// Create object stores (like tables)
request.onupgradeneeded = (event) => {
const db = event.target.result;
// Create object store
const objectStore = db.createObjectStore('users', {
keyPath: 'id',
autoIncrement: true
});
// Create indexes
objectStore.createIndex('email', 'email', { unique: true });
objectStore.createIndex('name', 'name', { unique: false });
console.log('Database upgraded');
};
request.onsuccess = (event) => {
const db = event.target.result;
console.log('Database opened successfully');
// Add data
const transaction = db.transaction(['users'], 'readwrite');
const objectStore = transaction.objectStore('users');
const user = {
name: 'John Doe',
email: 'john@example.com',
age: 30
};
const addRequest = objectStore.add(user);
addRequest.onsuccess = () => {
console.log('User added with ID:', addRequest.result);
};
// Get data by key
const getRequest = objectStore.get(1);
getRequest.onsuccess = () => {
console.log('User:', getRequest.result);
};
// Get by index
const index = objectStore.index('email');
const emailRequest = index.get('john@example.com');
emailRequest.onsuccess = () => {
console.log('User by email:', emailRequest.result);
};
// Update data
const updateRequest = objectStore.put({
id: 1,
name: 'John Smith',
email: 'john@example.com',
age: 31
});
// Delete data
const deleteRequest = objectStore.delete(1);
// Get all data
const getAllRequest = objectStore.getAll();
getAllRequest.onsuccess = () => {
console.log('All users:', getAllRequest.result);
};
// Cursor (iterate over records)
const cursorRequest = objectStore.openCursor();
cursorRequest.onsuccess = (event) => {
const cursor = event.target.result;
if (cursor) {
console.log('Record:', cursor.value);
cursor.continue(); // Move to next record
}
};
};
request.onerror = (event) => {
console.error('Database error:', event.target.error);
};
// Promise-based wrapper (easier to use)
class IndexedDBHelper {
constructor(dbName, version) {
this.dbName = dbName;
this.version = version;
this.db = null;
}
async open(upgrade) {
return new Promise((resolve, reject) => {
const request = indexedDB.open(this.dbName, this.version);
request.onupgradeneeded = (event) => {
if (upgrade) {
upgrade(event.target.result);
}
};
request.onsuccess = (event) => {
this.db = event.target.result;
resolve(this.db);
};
request.onerror = () => reject(request.error);
});
}
async add(storeName, data) {
return new Promise((resolve, reject) => {
const transaction = this.db.transaction([storeName], 'readwrite');
const store = transaction.objectStore(storeName);
const request = store.add(data);
request.onsuccess = () => resolve(request.result);
request.onerror = () => reject(request.error);
});
}
async get(storeName, key) {
return new Promise((resolve, reject) => {
const transaction = this.db.transaction([storeName], 'readonly');
const store = transaction.objectStore(storeName);
const request = store.get(key);
request.onsuccess = () => resolve(request.result);
request.onerror = () => reject(request.error);
});
}
async getAll(storeName) {
return new Promise((resolve, reject) => {
const transaction = this.db.transaction([storeName], 'readonly');
const store = transaction.objectStore(storeName);
const request = store.getAll();
request.onsuccess = () => resolve(request.result);
request.onerror = () => reject(request.error);
});
}
async update(storeName, data) {
return new Promise((resolve, reject) => {
const transaction = this.db.transaction([storeName], 'readwrite');
const store = transaction.objectStore(storeName);
const request = store.put(data);
request.onsuccess = () => resolve(request.result);
request.onerror = () => reject(request.error);
});
}
async delete(storeName, key) {
return new Promise((resolve, reject) => {
const transaction = this.db.transaction([storeName], 'readwrite');
const store = transaction.objectStore(storeName);
const request = store.delete(key);
request.onsuccess = () => resolve();
request.onerror = () => reject(request.error);
});
}
}
// Usage
const db = new IndexedDBHelper('MyApp', 1);
await db.open((database) => {
const store = database.createObjectStore('users', { keyPath: 'id', autoIncrement: true });
store.createIndex('email', 'email', { unique: true });
});
await db.add('users', { name: 'John', email: 'john@example.com' });
const users = await db.getAll('users');
console.log(users);
Cache API
Store network requests and responses:
// Open cache
const cache = await caches.open('my-cache-v1');
// Add to cache
await cache.add('/api/data');
await cache.addAll([
'/styles.css',
'/script.js',
'/image.png'
]);
// Put custom response in cache
const response = new Response(JSON.stringify({ data: 'cached' }), {
headers: { 'Content-Type': 'application/json' }
});
await cache.put('/api/custom', response);
// Get from cache
const cachedResponse = await cache.match('/api/data');
if (cachedResponse) {
const data = await cachedResponse.json();
console.log('Cached data:', data);
}
// Delete from cache
await cache.delete('/api/data');
// Get all keys
const keys = await cache.keys();
console.log('Cached URLs:', keys.map(req => req.url));
// Delete old caches
const cacheWhitelist = ['my-cache-v2'];
const cacheNames = await caches.keys();
await Promise.all(
cacheNames.map(cacheName => {
if (!cacheWhitelist.includes(cacheName)) {
return caches.delete(cacheName);
}
})
);
// Cache-first strategy
async function fetchWithCache(url) {
const cachedResponse = await caches.match(url);
if (cachedResponse) {
return cachedResponse;
}
const response = await fetch(url);
const cache = await caches.open('my-cache-v1');
cache.put(url, response.clone());
return response;
}
// Network-first strategy
async function fetchNetworkFirst(url) {
try {
const response = await fetch(url);
const cache = await caches.open('my-cache-v1');
cache.put(url, response.clone());
return response;
} catch (error) {
const cachedResponse = await caches.match(url);
if (cachedResponse) {
return cachedResponse;
}
throw error;
}
}
Web Workers
Worker (Background Threads)
Run JavaScript in background threads:
// main.js - Main thread
const worker = new Worker('worker.js');
// Send message to worker
worker.postMessage({ type: 'calculate', data: [1, 2, 3, 4, 5] });
// Receive message from worker
worker.onmessage = (event) => {
console.log('Result from worker:', event.data);
};
worker.onerror = (error) => {
console.error('Worker error:', error.message);
};
// Terminate worker
worker.terminate();
// ============================================
// worker.js - Worker thread
self.onmessage = (event) => {
const { type, data } = event.data;
if (type === 'calculate') {
// Perform heavy computation
const result = data.reduce((sum, num) => sum + num, 0);
// Send result back to main thread
self.postMessage(result);
}
};
// Worker can't access:
// - DOM
// - window object
// - document object
// - parent object
// Worker can access:
// - navigator
// - location (read-only)
// - XMLHttpRequest / fetch
// - setTimeout / setInterval
// - IndexedDB
// - Cache API
// ============================================
// Advanced: Transferable objects (zero-copy)
const buffer = new ArrayBuffer(1024 * 1024); // 1 MB
worker.postMessage({ buffer }, [buffer]); // Transfer ownership
// buffer is now unusable in main thread
// ============================================
// Inline worker (no separate file)
const code = `
self.onmessage = (e) => {
self.postMessage(e.data * 2);
};
`;
const blob = new Blob([code], { type: 'application/javascript' });
const workerUrl = URL.createObjectURL(blob);
const inlineWorker = new Worker(workerUrl);
inlineWorker.postMessage(5);
inlineWorker.onmessage = (e) => {
console.log('Result:', e.data); // 10
};
// ============================================
// Shared Worker (shared across tabs)
const sharedWorker = new SharedWorker('shared-worker.js');
sharedWorker.port.postMessage('hello');
sharedWorker.port.onmessage = (event) => {
console.log('From shared worker:', event.data);
};
// shared-worker.js
const connections = [];
self.onconnect = (event) => {
const port = event.ports[0];
connections.push(port);
port.onmessage = (e) => {
// Broadcast to all connections
connections.forEach(conn => {
conn.postMessage(`Broadcast: ${e.data}`);
});
};
};
Service Worker
Powerful worker for offline capabilities and caching:
// Register service worker
if ('serviceWorker' in navigator) {
navigator.serviceWorker.register('/service-worker.js')
.then(registration => {
console.log('Service Worker registered:', registration.scope);
// Check for updates
registration.addEventListener('updatefound', () => {
const newWorker = registration.installing;
console.log('New service worker installing');
newWorker.addEventListener('statechange', () => {
if (newWorker.state === 'installed') {
if (navigator.serviceWorker.controller) {
console.log('New version available, please refresh');
} else {
console.log('Content cached for offline use');
}
}
});
});
})
.catch(error => {
console.error('Service Worker registration failed:', error);
});
// Listen for messages from service worker
navigator.serviceWorker.addEventListener('message', (event) => {
console.log('Message from SW:', event.data);
});
}
// ============================================
// service-worker.js
const CACHE_NAME = 'my-app-v1';
const urlsToCache = [
'/',
'/styles.css',
'/script.js',
'/offline.html'
];
// Install event - cache resources
self.addEventListener('install', (event) => {
console.log('Service Worker installing');
event.waitUntil(
caches.open(CACHE_NAME)
.then(cache => {
console.log('Caching resources');
return cache.addAll(urlsToCache);
})
.then(() => self.skipWaiting()) // Activate immediately
);
});
// Activate event - clean up old caches
self.addEventListener('activate', (event) => {
console.log('Service Worker activating');
event.waitUntil(
caches.keys().then(cacheNames => {
return Promise.all(
cacheNames.map(cacheName => {
if (cacheName !== CACHE_NAME) {
console.log('Deleting old cache:', cacheName);
return caches.delete(cacheName);
}
})
);
}).then(() => self.clients.claim()) // Take control immediately
);
});
// Fetch event - serve from cache
self.addEventListener('fetch', (event) => {
event.respondWith(
caches.match(event.request)
.then(response => {
// Return cached version or fetch from network
return response || fetch(event.request)
.then(fetchResponse => {
// Cache new resources
return caches.open(CACHE_NAME)
.then(cache => {
cache.put(event.request, fetchResponse.clone());
return fetchResponse;
});
})
.catch(() => {
// Return offline page if fetch fails
return caches.match('/offline.html');
});
})
);
});
// Push notification event
self.addEventListener('push', (event) => {
const data = event.data ? event.data.json() : {};
event.waitUntil(
self.registration.showNotification(data.title, {
body: data.body,
icon: '/icon.png',
badge: '/badge.png',
data: data.url
})
);
});
// Notification click event
self.addEventListener('notificationclick', (event) => {
event.notification.close();
event.waitUntil(
clients.openWindow(event.notification.data)
);
});
// Sync event (background sync)
self.addEventListener('sync', (event) => {
if (event.tag === 'sync-messages') {
event.waitUntil(syncMessages());
}
});
async function syncMessages() {
// Sync pending messages
const messages = await getUnsyncedMessages();
await Promise.all(
messages.map(msg => fetch('/api/messages', {
method: 'POST',
body: JSON.stringify(msg)
}))
);
}
// Message from client
self.addEventListener('message', (event) => {
if (event.data.type === 'SKIP_WAITING') {
self.skipWaiting();
}
});
Notification API
Display system notifications:
// Request permission
async function requestNotificationPermission() {
const permission = await Notification.requestPermission();
if (permission === 'granted') {
console.log('Notification permission granted');
} else if (permission === 'denied') {
console.log('Notification permission denied');
}
}
// Check current permission
console.log('Permission:', Notification.permission);
// 'default', 'granted', or 'denied'
// Show notification (simple)
if (Notification.permission === 'granted') {
new Notification('Hello!', {
body: 'This is a notification',
icon: '/icon.png',
badge: '/badge.png'
});
}
// Show notification (advanced)
const notification = new Notification('New Message', {
body: 'You have 3 new messages',
icon: '/icon.png',
badge: '/badge.png',
image: '/banner.png',
tag: 'message-notification', // Replaces notifications with same tag
renotify: true, // Notify even if same tag exists
requireInteraction: false, // Auto-dismiss
silent: false,
vibrate: [200, 100, 200], // Vibration pattern
timestamp: Date.now(),
data: { url: '/messages' } // Custom data
});
// NOTE: action buttons ('actions') are only supported on notifications shown
// via registration.showNotification(), not via the Notification constructor
// Event handlers
notification.onclick = (event) => {
console.log('Notification clicked');
window.focus();
notification.close();
};
notification.onclose = () => {
console.log('Notification closed');
};
notification.onerror = (error) => {
console.error('Notification error:', error);
};
notification.onshow = () => {
console.log('Notification shown');
};
// Close notification
setTimeout(() => {
notification.close();
}, 5000);
// Service Worker notifications (recommended)
if ('serviceWorker' in navigator) {
navigator.serviceWorker.ready.then(registration => {
registration.showNotification('Title', {
body: 'Body text',
icon: '/icon.png',
actions: [
{ action: 'yes', title: 'Yes' },
{ action: 'no', title: 'No' }
]
});
});
}
Geolocation API
Access device location:
// Check if available
if ('geolocation' in navigator) {
console.log('Geolocation is available');
}
// Get current position (one-time)
navigator.geolocation.getCurrentPosition(
// Success callback
(position) => {
console.log('Latitude:', position.coords.latitude);
console.log('Longitude:', position.coords.longitude);
console.log('Accuracy:', position.coords.accuracy, 'meters');
console.log('Altitude:', position.coords.altitude);
console.log('Altitude accuracy:', position.coords.altitudeAccuracy);
console.log('Heading:', position.coords.heading); // Direction of travel
console.log('Speed:', position.coords.speed); // meters/second
console.log('Timestamp:', position.timestamp);
},
// Error callback
(error) => {
switch (error.code) {
case error.PERMISSION_DENIED:
console.error('User denied geolocation');
break;
case error.POSITION_UNAVAILABLE:
console.error('Position unavailable');
break;
case error.TIMEOUT:
console.error('Request timeout');
break;
}
},
// Options
{
enableHighAccuracy: true, // Use GPS (more battery)
timeout: 5000, // 5 seconds
maximumAge: 0 // Don't use cached position
}
);
// Watch position (continuous updates)
const watchId = navigator.geolocation.watchPosition(
(position) => {
console.log('Position updated:', position.coords);
updateMapMarker(position.coords.latitude, position.coords.longitude);
},
(error) => {
console.error('Watch error:', error);
},
{
enableHighAccuracy: true,
timeout: 10000,
maximumAge: 5000 // Use cached position if < 5 seconds old
}
);
// Stop watching
navigator.geolocation.clearWatch(watchId);
// Promise-based wrapper
function getPosition(options = {}) {
return new Promise((resolve, reject) => {
navigator.geolocation.getCurrentPosition(resolve, reject, options);
});
}
// Usage
try {
const position = await getPosition({ enableHighAccuracy: true });
console.log('Position:', position.coords);
} catch (error) {
console.error('Error getting position:', error);
}
File API
Read and manipulate files:
// File input
const input = document.getElementById('fileInput');
input.addEventListener('change', async (event) => {
const files = event.target.files;
for (const file of files) {
console.log('Name:', file.name);
console.log('Size:', file.size, 'bytes');
console.log('Type:', file.type);
console.log('Last modified:', new Date(file.lastModified));
// Read as text
const text = await file.text();
console.log('Content:', text);
// Read as ArrayBuffer
const buffer = await file.arrayBuffer();
console.log('Buffer:', new Uint8Array(buffer));
// Read as Data URL (base64)
const reader = new FileReader();
reader.onload = (e) => {
console.log('Data URL:', e.target.result);
// Can use as img src
document.getElementById('preview').src = e.target.result;
};
reader.readAsDataURL(file);
// Read as text with FileReader (a FileReader can only run one read
// at a time, so use a separate instance per read)
const textReader = new FileReader();
textReader.onload = (e) => {
console.log('Text:', e.target.result);
};
textReader.readAsText(file);
// Read as ArrayBuffer with FileReader
const bufferReader = new FileReader();
bufferReader.onload = (e) => {
const buffer = e.target.result;
console.log('Buffer:', new Uint8Array(buffer));
};
bufferReader.readAsArrayBuffer(file);
// Progress event
reader.onprogress = (e) => {
if (e.lengthComputable) {
const percent = (e.loaded / e.total) * 100;
console.log('Progress:', percent.toFixed(2) + '%');
}
};
}
});
// Drag and drop
const dropZone = document.getElementById('dropZone');
dropZone.addEventListener('dragover', (e) => {
e.preventDefault();
dropZone.classList.add('drag-over');
});
dropZone.addEventListener('dragleave', () => {
dropZone.classList.remove('drag-over');
});
dropZone.addEventListener('drop', async (e) => {
e.preventDefault();
dropZone.classList.remove('drag-over');
const files = e.dataTransfer.files;
for (const file of files) {
console.log('Dropped file:', file.name);
}
});
// Create Blob
const blob = new Blob(['Hello, World!'], { type: 'text/plain' });
console.log('Blob size:', blob.size);
console.log('Blob type:', blob.type);
// Read Blob
const text = await blob.text();
console.log('Blob text:', text);
// Blob to URL
const url = URL.createObjectURL(blob);
console.log('Blob URL:', url);
// Don't forget to revoke
URL.revokeObjectURL(url);
// Download file
function downloadFile(content, filename, type) {
const blob = new Blob([content], { type });
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = filename;
a.click();
URL.revokeObjectURL(url);
}
downloadFile('Hello, World!', 'hello.txt', 'text/plain');
// File from Blob
const file = new File([blob], 'example.txt', {
type: 'text/plain',
lastModified: Date.now()
});
Clipboard API
Read and write clipboard:
// Write text to clipboard
async function copyText(text) {
try {
await navigator.clipboard.writeText(text);
console.log('Text copied to clipboard');
} catch (error) {
console.error('Failed to copy:', error);
}
}
copyText('Hello, clipboard!');
// Read text from clipboard
async function pasteText() {
try {
const text = await navigator.clipboard.readText();
console.log('Pasted text:', text);
return text;
} catch (error) {
console.error('Failed to read clipboard:', error);
}
}
// Write images/rich content
async function copyImage(blob) {
try {
const item = new ClipboardItem({ 'image/png': blob });
await navigator.clipboard.write([item]);
console.log('Image copied');
} catch (error) {
console.error('Failed to copy image:', error);
}
}
// Read images/rich content
async function pasteImage() {
try {
const items = await navigator.clipboard.read();
for (const item of items) {
for (const type of item.types) {
const blob = await item.getType(type);
if (type.startsWith('image/')) {
const url = URL.createObjectURL(blob);
const img = document.createElement('img');
img.src = url;
document.body.appendChild(img);
}
}
}
} catch (error) {
console.error('Failed to paste:', error);
}
}
// Copy button example
document.getElementById('copyBtn').addEventListener('click', async () => {
const text = document.getElementById('text').textContent;
await copyText(text);
alert('Copied!');
});
// Legacy approach (fallback)
function copyTextLegacy(text) {
const textarea = document.createElement('textarea');
textarea.value = text;
textarea.style.position = 'fixed';
textarea.style.opacity = '0';
document.body.appendChild(textarea);
textarea.select();
document.execCommand('copy');
document.body.removeChild(textarea);
}
Intersection Observer API
Detect element visibility:
// Create observer
const observer = new IntersectionObserver(
(entries, observer) => {
entries.forEach(entry => {
if (entry.isIntersecting) {
console.log('Element is visible:', entry.target);
// Lazy load image
if (entry.target.tagName === 'IMG') {
entry.target.src = entry.target.dataset.src;
observer.unobserve(entry.target); // Stop observing
}
// Animation on scroll
entry.target.classList.add('animate-in');
} else {
console.log('Element is hidden:', entry.target);
}
});
},
{
root: null, // viewport
rootMargin: '0px', // margin around root
threshold: 0.5 // 50% visible
}
);
// Observe elements
const images = document.querySelectorAll('img[data-src]');
images.forEach(img => observer.observe(img));
// Multiple thresholds
const detailedObserver = new IntersectionObserver(
(entries) => {
entries.forEach(entry => {
console.log('Visibility:', entry.intersectionRatio);
// 0 = not visible, 1 = fully visible
});
},
{
threshold: [0, 0.25, 0.5, 0.75, 1.0]
}
);
// Infinite scroll example
const loadMore = document.getElementById('loadMore');
const infiniteObserver = new IntersectionObserver(
(entries) => {
if (entries[0].isIntersecting) {
console.log('Load more items');
loadMoreItems().then(items => {
appendItems(items);
});
}
},
{ threshold: 1.0 }
);
infiniteObserver.observe(loadMore);
// Unobserve element
observer.unobserve(element);
// Disconnect observer
observer.disconnect();
Mutation Observer API
Watch for DOM changes:
// Create observer
const mutationObserver = new MutationObserver((mutations) => {
mutations.forEach(mutation => {
console.log('Type:', mutation.type);
if (mutation.type === 'childList') {
console.log('Children changed');
console.log('Added:', mutation.addedNodes);
console.log('Removed:', mutation.removedNodes);
}
if (mutation.type === 'attributes') {
console.log('Attribute changed:', mutation.attributeName);
console.log('Old value:', mutation.oldValue);
}
if (mutation.type === 'characterData') {
console.log('Text content changed');
console.log('Old value:', mutation.oldValue);
}
});
});
// Observe element
const targetNode = document.getElementById('target');
mutationObserver.observe(targetNode, {
childList: true, // Watch for child additions/removals
attributes: true, // Watch for attribute changes
characterData: true, // Watch for text content changes
subtree: true, // Watch descendants too
attributeOldValue: true, // Record old attribute value
characterDataOldValue: true, // Record old text value
attributeFilter: ['class', 'style'] // Only watch specific attributes
});
// Disconnect observer
mutationObserver.disconnect();
// Example: Watch for dynamically added elements
const bodyObserver = new MutationObserver((mutations) => {
mutations.forEach(mutation => {
mutation.addedNodes.forEach(node => {
if (node.classList && node.classList.contains('dynamic-content')) {
console.log('Dynamic content added:', node);
initializeDynamicContent(node);
}
});
});
});
bodyObserver.observe(document.body, {
childList: true,
subtree: true
});
Resize Observer API
Detect element size changes:
// Create observer
const resizeObserver = new ResizeObserver((entries) => {
entries.forEach(entry => {
console.log('Element:', entry.target);
console.log('Content box:', entry.contentBoxSize);
console.log('Border box:', entry.borderBoxSize);
console.log('Device pixel box:', entry.devicePixelContentBoxSize);
const width = entry.contentRect.width;
const height = entry.contentRect.height;
console.log('Size:', width, 'x', height);
// Responsive behavior
if (width < 600) {
entry.target.classList.add('mobile');
} else {
entry.target.classList.remove('mobile');
}
});
});
// Observe element
const element = document.getElementById('resizable');
resizeObserver.observe(element);
// Observe multiple elements
const elements = document.querySelectorAll('.resizable');
elements.forEach(el => resizeObserver.observe(el));
// Unobserve
resizeObserver.unobserve(element);
// Disconnect
resizeObserver.disconnect();
// Example: Canvas responsive rendering
const canvas = document.getElementById('canvas');
const canvasObserver = new ResizeObserver((entries) => {
const entry = entries[0];
const width = entry.contentRect.width;
const height = entry.contentRect.height;
// Update canvas size
canvas.width = width * devicePixelRatio;
canvas.height = height * devicePixelRatio;
// Re-render
renderCanvas();
});
canvasObserver.observe(canvas);
Page Visibility API
Detect when page is visible:
// Check current visibility
console.log('Hidden:', document.hidden);
console.log('Visibility state:', document.visibilityState);
// 'visible' or 'hidden' ('prerender' was removed from the spec)
// Listen for visibility changes
document.addEventListener('visibilitychange', () => {
if (document.hidden) {
console.log('Page is hidden');
// Pause video
video.pause();
// Stop animations
stopAnimations();
// Reduce network activity
clearInterval(pollingInterval);
} else {
console.log('Page is visible');
// Resume video
video.play();
// Resume animations
startAnimations();
// Resume polling
startPolling();
}
});
// Example: Pause game when tab is hidden
document.addEventListener('visibilitychange', () => {
if (document.hidden) {
game.pause();
} else {
game.resume();
}
});
// Example: Analytics
let startTime = Date.now();
document.addEventListener('visibilitychange', () => {
if (document.hidden) {
const visibleTime = Date.now() - startTime;
analytics.track('time-visible', visibleTime);
} else {
startTime = Date.now();
}
});
Broadcast Channel API
Communicate between tabs/windows:
// Create channel
const channel = new BroadcastChannel('my-channel');
// Send message
channel.postMessage('Hello from tab 1');
channel.postMessage({ type: 'update', data: { count: 5 } });
// Receive messages
channel.onmessage = (event) => {
console.log('Received message:', event.data);
if (event.data.type === 'update') {
updateUI(event.data.data);
}
};
channel.onerror = (error) => {
console.error('Channel error:', error);
};
// Close channel
channel.close();
// Example: Sync state across tabs
const stateChannel = new BroadcastChannel('app-state');
// Tab 1: Update state
function updateState(newState) {
state = newState;
localStorage.setItem('state', JSON.stringify(state));
stateChannel.postMessage({ type: 'state-update', state });
}
// All tabs: Listen for updates
stateChannel.onmessage = (event) => {
if (event.data.type === 'state-update') {
state = event.data.state;
renderUI();
}
};
// Example: Logout all tabs
const authChannel = new BroadcastChannel('auth');
// Tab with logout button
function logout() {
clearAuthToken();
authChannel.postMessage({ type: 'logout' });
redirectToLogin();
}
// All tabs
authChannel.onmessage = (event) => {
if (event.data.type === 'logout') {
clearAuthToken();
redirectToLogin();
}
};
History API
Manipulate browser history:
// Push new state
history.pushState(
{ page: 1 }, // State object
'Title', // Title (ignored by most browsers)
'/page/1' // URL
);
// Replace current state
history.replaceState({ page: 2 }, 'Title', '/page/2');
// Go back
history.back();
// Go forward
history.forward();
// Go to specific point
history.go(-2); // Go back 2 pages
history.go(1); // Go forward 1 page
// Listen for state changes
window.addEventListener('popstate', (event) => {
console.log('State:', event.state);
console.log('URL:', location.pathname);
// Restore page state
if (event.state && event.state.page) {
loadPage(event.state.page);
}
});
// Get current state
console.log('Current state:', history.state);
// Length of history
console.log('History length:', history.length);
// Example: Single Page App navigation
function navigateTo(url, state = {}) {
history.pushState(state, '', url);
loadContent(url);
}
document.querySelectorAll('a[data-link]').forEach(link => {
link.addEventListener('click', (e) => {
e.preventDefault();
navigateTo(link.href);
});
});
window.addEventListener('popstate', () => {
loadContent(location.pathname);
});
Performance API
Measure performance:
// Mark time points
performance.mark('start-task');
// Do some work
await doSomethingExpensive();
performance.mark('end-task');
// Measure duration
performance.measure('task-duration', 'start-task', 'end-task');
// Get measurements
const measures = performance.getEntriesByName('task-duration');
console.log('Duration:', measures[0].duration, 'ms');
// Navigation timing
const navTiming = performance.getEntriesByType('navigation')[0];
console.log('DNS lookup:', navTiming.domainLookupEnd - navTiming.domainLookupStart);
console.log('TCP connect:', navTiming.connectEnd - navTiming.connectStart);
console.log('Request time:', navTiming.responseEnd - navTiming.requestStart);
console.log('DOM ready:', navTiming.domContentLoadedEventEnd, 'ms after navigation start');
console.log('Page load:', navTiming.loadEventEnd, 'ms after navigation start');
// Resource timing
const resources = performance.getEntriesByType('resource');
resources.forEach(resource => {
console.log('Resource:', resource.name);
console.log('Duration:', resource.duration);
console.log('Size:', resource.transferSize);
});
// Paint timing
const paintTiming = performance.getEntriesByType('paint');
paintTiming.forEach(entry => {
console.log(`${entry.name}:`, entry.startTime);
});
// first-paint, first-contentful-paint
// Clear marks and measures
performance.clearMarks();
performance.clearMeasures();
// Observer for performance entries
const perfObserver = new PerformanceObserver((list) => {
list.getEntries().forEach(entry => {
console.log('Entry:', entry.name, entry.duration);
});
});
perfObserver.observe({ entryTypes: ['measure', 'navigation', 'resource'] });
// Memory usage (Chrome only)
if (performance.memory) {
console.log('Used heap:', performance.memory.usedJSHeapSize);
console.log('Total heap:', performance.memory.totalJSHeapSize);
console.log('Heap limit:', performance.memory.jsHeapSizeLimit);
}
// Current time (high-resolution)
const start = performance.now();
// Do work
const end = performance.now();
console.log('Elapsed:', end - start, 'ms');
Battery Status API
Get battery information (limited support: Chromium-based browsers only):
if ('getBattery' in navigator) {
const battery = await navigator.getBattery();
console.log('Charging:', battery.charging);
console.log('Level:', battery.level * 100 + '%');
console.log('Charging time:', battery.chargingTime, 'seconds');
console.log('Discharging time:', battery.dischargingTime, 'seconds');
// Listen for changes
battery.addEventListener('chargingchange', () => {
console.log('Charging:', battery.charging);
});
battery.addEventListener('levelchange', () => {
console.log('Battery level:', battery.level * 100 + '%');
if (battery.level < 0.2 && !battery.charging) {
alert('Low battery! Please charge your device.');
}
});
battery.addEventListener('chargingtimechange', () => {
console.log('Charging time:', battery.chargingTime);
});
battery.addEventListener('dischargingtimechange', () => {
console.log('Discharging time:', battery.dischargingTime);
});
// Adaptive features based on battery
if (battery.level < 0.2 && !battery.charging) {
// Reduce animations, polling, etc.
enablePowerSavingMode();
}
}
Web Share API
Share content from web app:
// Check if supported
if (navigator.share) {
console.log('Web Share API supported');
}
// Share text
async function shareText() {
try {
await navigator.share({
title: 'Check this out!',
text: 'This is amazing content',
url: 'https://example.com'
});
console.log('Shared successfully');
} catch (error) {
console.error('Error sharing:', error);
}
}
// Share files
async function shareFiles(files) {
if (navigator.canShare && navigator.canShare({ files })) {
try {
await navigator.share({
files: files,
title: 'Shared files',
text: 'Check out these files'
});
console.log('Files shared successfully');
} catch (error) {
console.error('Error sharing files:', error);
}
} else {
console.log('File sharing not supported');
}
}
// Example: Share button
document.getElementById('shareBtn').addEventListener('click', async () => {
if (navigator.share) {
await shareText();
} else {
// Fallback: Copy link
await navigator.clipboard.writeText(window.location.href);
alert('Link copied to clipboard');
}
});
// Example: Share image
const canvas = document.getElementById('canvas');
canvas.toBlob(async (blob) => {
const file = new File([blob], 'image.png', { type: 'image/png' });
await shareFiles([file]);
});
Browser Support and Feature Detection
Always check for API availability:
// Feature detection
const features = {
serviceWorker: 'serviceWorker' in navigator,
pushNotifications: 'PushManager' in window,
notifications: 'Notification' in window,
geolocation: 'geolocation' in navigator,
webWorker: typeof Worker !== 'undefined',
indexedDB: 'indexedDB' in window,
webRTC: 'RTCPeerConnection' in window,
webGL: (() => {
const canvas = document.createElement('canvas');
return !!(canvas.getContext('webgl') || canvas.getContext('experimental-webgl'));
})(),
mediaDevices: 'mediaDevices' in navigator,
clipboard: 'clipboard' in navigator,
share: 'share' in navigator,
battery: 'getBattery' in navigator
};
console.table(features);
// Polyfill loading
if (!window.IntersectionObserver) {
await import('intersection-observer');
}
// Progressive enhancement
if ('serviceWorker' in navigator) {
// Enable offline support
registerServiceWorker();
} else {
// Gracefully degrade
console.log('Service Worker not supported');
}
Best Practices
// 1. Always check feature support
if ('geolocation' in navigator) {
// Use geolocation
}
// 2. Handle errors gracefully
try {
await navigator.clipboard.writeText('text');
} catch (error) {
// Fallback
fallbackCopyMethod('text');
}
// 3. Request permissions appropriately
// Don't request permission immediately on page load
document.getElementById('enableNotifications').addEventListener('click', async () => {
await Notification.requestPermission();
});
// 4. Clean up resources
const observer = new IntersectionObserver(callback);
// When done:
observer.disconnect();
const worker = new Worker('worker.js');
// When done:
worker.terminate();
// 5. Use Promises/async-await for better readability
// Instead of callbacks
async function loadData() {
const data = await fetch('/api/data').then(r => r.json());
return data;
}
// 6. Respect user privacy
// Check permission status before requesting
const status = await navigator.permissions.query({ name: 'geolocation' });
if (status.state === 'granted') {
// Already have permission
}
// 7. Optimize performance
// Debounce expensive operations
function debounce(func, wait) {
let timeout;
return function (...args) {
clearTimeout(timeout);
timeout = setTimeout(() => func.apply(this, args), wait);
};
}
window.addEventListener('resize', debounce(() => {
console.log('Resized');
}, 250));
Further Resources
Documentation
- MDN Web APIs
- Can I Use - Browser support tables
- Web.dev - Modern web development guides
Specifications
Tools
- Lighthouse - Performance auditing
- Workbox - Service Worker library
Libraries
- Dexie.js - IndexedDB wrapper
- localForage - Unified storage API
- Comlink - Web Worker RPC
REST APIs
Overview
REST (Representational State Transfer) is an architectural style for building web services using HTTP.
Core Principles
- Client-Server: Separation of concerns
- Stateless: Each request carries all the information needed to process it
- Uniform Interface: Consistent API design
- Cacheable: Responses can be cached
- Layered System: The client is unaware of intermediate layers
HTTP Methods
| Method | Purpose | Idempotent |
|---|---|---|
| GET | Retrieve resource | ✓ |
| POST | Create resource | ✗ |
| PUT | Replace resource | ✓ |
| PATCH | Partial update | ✗ |
| DELETE | Remove resource | ✓ |
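The methods above map directly onto fetch calls. A minimal client sketch against a hypothetical /users resource (the base URL and field names are assumptions):

```javascript
// Hypothetical API base URL
const API = 'https://api.example.com';

// GET - retrieve (idempotent, no body)
const getUser = (id) => fetch(`${API}/users/${id}`).then(r => r.json());

// POST - create (not idempotent: repeating it creates duplicates)
const createUser = (user) =>
fetch(`${API}/users`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(user)
}).then(r => r.json());

// PUT - replace the whole resource
const replaceUser = (id, user) =>
fetch(`${API}/users/${id}`, {
method: 'PUT',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(user)
}).then(r => r.json());

// PATCH - send only the changed fields
const updateUser = (id, changes) =>
fetch(`${API}/users/${id}`, {
method: 'PATCH',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(changes)
}).then(r => r.json());

// DELETE - remove (idempotent: deleting twice leaves the same state)
const deleteUser = (id) =>
fetch(`${API}/users/${id}`, { method: 'DELETE' });
```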
Status Codes
- 2xx: Success (200 OK, 201 Created)
- 3xx: Redirection (301, 304)
- 4xx: Client error (400, 404, 401)
- 5xx: Server error (500, 503)
Resource-Oriented Design
✓ GET /users - List users
✓ POST /users - Create user
✓ GET /users/123 - Get user 123
✓ PUT /users/123 - Replace user 123
✓ PATCH /users/123 - Partial update
✓ DELETE /users/123 - Delete user 123
✗ GET /getUser?id=123 - Procedural (bad)
Request/Response
# Request
GET /api/v1/users/123 HTTP/1.1
Host: api.example.com
Authorization: Bearer token
Content-Type: application/json
# Response
HTTP/1.1 200 OK
Content-Type: application/json
{
"id": 123,
"name": "John",
"email": "john@example.com"
}
Error Handling
{
"error": "Validation failed",
"details": {
"email": "Invalid email format"
},
"status": 400
}
Pagination
GET /users?page=2&limit=20
GET /users?offset=40&limit=20
GET /users?cursor=abc123
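A sketch of handling the page/limit style server-side. The parameter names, defaults, and the page-size cap are assumptions:

```javascript
// Convert ?page=2&limit=20 query params into an array slice
function paginate(items, { page = 1, limit = 20 } = {}) {
const p = Math.max(1, parseInt(page, 10) || 1);
const l = Math.min(100, Math.max(1, parseInt(limit, 10) || 20)); // cap page size
const start = (p - 1) * l;
return {
data: items.slice(start, start + l),
page: p,
limit: l,
total: items.length,
hasMore: start + l < items.length
};
}

// Usage inside an Express handler: res.json(paginate(users, req.query))
```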
Versioning
/api/v1/users (stable)
/api/v2/users (new version)
/api/beta/users (experimental)
Best Practices
- Use appropriate methods for operations
- Meaningful status codes for responses
- Consistent naming conventions
- Pagination for large datasets
- Rate limiting to protect API
- Authentication/Authorization
- Documentation (Swagger/OpenAPI)
Express.js Example
const express = require('express');
const app = express();
app.use(express.json()); // Needed to parse JSON request bodies
// Get all users
app.get('/users', (req, res) => {
res.json(users);
});
// Get user by ID
app.get('/users/:id', (req, res) => {
const user = users.find(u => u.id == req.params.id); // == because params are strings
if (!user) return res.status(404).json({ error: 'User not found' });
res.json(user);
});
// Create user
app.post('/users', (req, res) => {
const user = req.body;
users.push(user);
res.status(201).json(user);
});
// Update user
app.patch('/users/:id', (req, res) => {
const user = users.find(u => u.id == req.params.id);
if (!user) return res.status(404).json({ error: 'User not found' });
Object.assign(user, req.body);
res.json(user);
});
// Delete user
app.delete('/users/:id', (req, res) => {
users = users.filter(u => u.id != req.params.id);
res.status(204).send();
});
app.listen(3000);
Testing
# Using curl
curl -X GET http://localhost:3000/users
curl -X POST http://localhost:3000/users -H "Content-Type: application/json" -d '{"name":"John"}'
ELI10
REST API is like a restaurant menu:
- GET: View menu/food
- POST: Place new order
- PUT: Replace entire order
- PATCH: Modify order slightly
- DELETE: Cancel order
Standard ways to order without confusion!
GraphQL
Overview
GraphQL is a query language for APIs. Request exactly what data you need, no more, no less.
Key Differences from REST
| Aspect | REST | GraphQL |
|---|---|---|
| Endpoints | Multiple (/users, /posts, /comments) | Single (/graphql) |
| Data | Fixed shape | Client specifies shape |
| Over-fetching | Get extra fields | Only requested fields |
| Under-fetching | Need multiple requests | Single request |
Schema
Define types and their relationships:
type User {
id: ID!
name: String!
email: String!
posts: [Post!]!
age: Int
}
type Post {
id: ID!
title: String!
content: String!
author: User!
createdAt: String!
}
type Query {
user(id: ID!): User
users: [User!]!
post(id: ID!): Post
}
type Mutation {
createUser(name: String!, email: String!): User!
updateUser(id: ID!, name: String): User
deleteUser(id: ID!): Boolean!
}
Queries
Request exactly what you need:
# Simple query
query {
user(id: "1") {
name
email
}
}
# Nested query
query {
user(id: "1") {
name
posts {
title
createdAt
}
}
}
# Multiple queries
query {
user1: user(id: "1") {
name
}
user2: user(id: "2") {
name
}
}
# With variables
query GetUser($userId: ID!) {
user(id: $userId) {
name
email
posts {
title
}
}
}
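On the client, a query with variables is typically sent as a JSON POST to the single endpoint. A sketch (the /graphql path is a common convention, assumed here):

```javascript
const GET_USER = `
query GetUser($userId: ID!) {
user(id: $userId) {
name
email
}
}
`;

// Build the request body shape GraphQL servers expect:
// { query, variables }
function buildGraphQLRequest(query, variables = {}) {
return {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ query, variables })
};
}

// Usage (assumes a /graphql endpoint)
async function fetchUser(userId) {
const res = await fetch('/graphql', buildGraphQLRequest(GET_USER, { userId }));
const { data, errors } = await res.json();
if (errors) throw new Error(errors[0].message);
return data.user;
}
```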
Mutations
Modify data:
mutation CreateUser($name: String!, $email: String!) {
createUser(name: $name, email: $email) {
id
name
email
}
}
mutation UpdateUser($id: ID!, $name: String) {
updateUser(id: $id, name: $name) {
id
name
}
}
Resolvers
Implement schema with resolvers:
const resolvers = {
Query: {
user: (parent, args) => {
return db.users.find(u => u.id === args.id);
},
users: () => {
return db.users;
}
},
Mutation: {
createUser: (parent, args) => {
const user = { id: uuidv4(), ...args };
db.users.push(user);
return user;
}
},
User: {
posts: (parent) => {
return db.posts.filter(p => p.authorId === parent.id);
}
}
};
Apollo Server (Node.js)
const { ApolloServer, gql } = require('apollo-server');
const typeDefs = gql`
type Query {
hello: String
user(id: ID!): User
}
type User {
id: ID!
name: String!
}
`;
const resolvers = {
Query: {
hello: () => 'Hello world!',
user: (_, args) => ({ id: args.id, name: 'John' })
}
};
const server = new ApolloServer({
typeDefs,
resolvers
});
server.listen().then(({ url }) => console.log(`Server ready at ${url}`));
Advantages
✅ Request only needed data (no over-fetching)
✅ Single request for related data (no under-fetching)
✅ Strong typing with schema
✅ Introspection (explore API automatically)
✅ Development tools (GraphQL Explorer)
Disadvantages
❌ More complex than REST
❌ Query complexity attacks
❌ Caching is harder
❌ Monitoring is harder
❌ Learning curve
Best Practices
- Limit query depth (prevent abuse)
- Implement timeout on queries
- Use pagination for large result sets
- Combine with REST if needed
- Monitor query performance
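The first best practice (limiting query depth) can be illustrated with a naive sketch. This is an assumption for illustration only: a real GraphQL server would parse the query and walk the AST (e.g. with graphql-core), since braces can also appear inside string literals.

```python
def query_depth(query: str) -> int:
    """Estimate selection-set nesting by tracking curly-brace depth (naive sketch)."""
    depth = max_depth = 0
    for ch in query:
        if ch == "{":
            depth += 1
            max_depth = max(max_depth, depth)
        elif ch == "}":
            depth -= 1
    return max_depth


MAX_DEPTH = 5  # hypothetical limit for this sketch

def check_query(query: str) -> None:
    """Reject queries nested deeper than MAX_DEPTH."""
    if query_depth(query) > MAX_DEPTH:
        raise ValueError("query too deep")
```

A server would run this check before executing the query, returning an error instead of letting a deeply nested (possibly malicious) query fan out across resolvers.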
Pagination
query {
users(first: 10, after: "cursor123") {
edges {
node {
id
name
}
cursor
}
pageInfo {
hasNextPage
endCursor
}
}
}
ELI10
GraphQL is like ordering food:
- REST: Get whole menu as-is
- GraphQL: Ask for exactly what you want
"I'll take pasta with sauce on the side, hold the onions"
Further Resources
gRPC
gRPC is a high-performance, open-source universal RPC framework. It uses HTTP/2 for transport, Protocol Buffers as the interface description language, and provides features like authentication, load balancing, and more.
Overview
Key Features:
- HTTP/2 based transport
- Protocol Buffers for serialization
- Bidirectional streaming
- Pluggable auth, tracing, load balancing
- Language-agnostic
Protocol Buffers
// user.proto
syntax = "proto3";
package user;
service UserService {
rpc GetUser(UserRequest) returns (UserResponse);
rpc ListUsers(ListUsersRequest) returns (stream UserResponse);
}
message UserRequest {
int32 id = 1;
}
message UserResponse {
int32 id = 1;
string name = 2;
string email = 3;
}
message ListUsersRequest {
int32 page = 1;
int32 page_size = 2;
}
Server Implementation (Python)
import grpc
from concurrent import futures
import user_pb2
import user_pb2_grpc
class UserServiceServicer(user_pb2_grpc.UserServiceServicer):
def GetUser(self, request, context):
# Fetch user from database
return user_pb2.UserResponse(
id=request.id,
name="John Doe",
email="john@example.com"
)
def ListUsers(self, request, context):
# Stream users
for user in get_users():
yield user_pb2.UserResponse(
id=user.id,
name=user.name,
email=user.email
)
def serve():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
user_pb2_grpc.add_UserServiceServicer_to_server(
UserServiceServicer(), server
)
server.add_insecure_port('[::]:50051')
server.start()
server.wait_for_termination()
Client Implementation
import grpc
import user_pb2
import user_pb2_grpc
def run():
with grpc.insecure_channel('localhost:50051') as channel:
stub = user_pb2_grpc.UserServiceStub(channel)
# Unary call
response = stub.GetUser(user_pb2.UserRequest(id=1))
print(f"User: {response.name}")
# Server streaming
for user in stub.ListUsers(user_pb2.ListUsersRequest(page=1)):
print(f"User: {user.name}")
Stream Types
| Type | Description |
|---|---|
| Unary | Single request/response |
| Server streaming | Client sends one request, server streams responses |
| Client streaming | Client streams requests, server sends one response |
| Bidirectional | Both stream |
gRPC provides efficient, type-safe communication between services, ideal for microservices architectures.
DevOps & CI/CD
DevOps practices, tools, and methodologies for continuous integration, delivery, and deployment.
Topics Covered
- CI/CD: Continuous integration and deployment pipelines, automation, workflows
- Docker: Containerization, images, containers, Docker Compose
- Kubernetes: Container orchestration, deployments, services, scaling
- Terraform: Infrastructure as code, cloud provisioning
- GitHub Actions: CI/CD workflows, automation
- Monitoring: Logging, metrics, observability
- Cloud Deployment: AWS, GCP, Azure
- Infrastructure: Networking, security, scaling
Key Concepts
- Continuous Integration: Automated testing on every commit
- Continuous Delivery: Automated deployment ready
- Continuous Deployment: Automated production releases
- Infrastructure as Code: Define infra with code
Tools
- CI/CD: Jenkins, GitHub Actions, GitLab CI, CircleCI
- Container: Docker, Podman
- Orchestration: Kubernetes, Docker Swarm
- IaC: Terraform, CloudFormation, Pulumi
- Monitoring: Prometheus, ELK Stack, Datadog
Navigation
Explore each tool and practice to master DevOps.
Docker
Overview
Docker packages applications into containers - lightweight, isolated environments with all dependencies. Build once, run anywhere.
Core Concepts
Images vs Containers
- Image: Blueprint (read-only template)
- Container: Running instance of image
# Build image
docker build -t myapp:1.0 .
# Run container from image
docker run myapp:1.0
Dockerfile
# Base image
FROM python:3.9-slim
# Set working directory
WORKDIR /app
# Copy files
COPY requirements.txt .
# Install dependencies
RUN pip install -r requirements.txt
# Copy application
COPY . .
# Expose port
EXPOSE 5000
# Run command
CMD ["python", "app.py"]
Docker Commands
# Build
docker build -t myapp:1.0 .
# Run
docker run -p 8000:5000 myapp:1.0
docker run -d -p 8000:5000 myapp:1.0 # Detached
# View containers
docker ps # Running
docker ps -a # All
# View images
docker images
# Logs
docker logs container_id
# Stop container
docker stop container_id
# Remove
docker rm container_id
docker rmi image_name
Docker Compose
Multiple containers together:
version: '3.8'
services:
web:
build: .
ports:
- "8000:5000"
environment:
DATABASE_URL: postgres://db:5432/mydb
depends_on:
- db
db:
image: postgres:13
environment:
POSTGRES_PASSWORD: secret
volumes:
- postgres_data:/var/lib/postgresql/data
volumes:
postgres_data:
docker-compose up # Start all services
docker-compose down # Stop all services
docker-compose logs -f # Follow logs
Best Practices
- Small images: Use minimal base images (alpine)
- Layer caching: Order commands by change frequency
- Security: Don't run as root, use secrets
- Health checks: Monitor container health
# Good: Minimal image
FROM python:3.9-slim
RUN pip install --no-cache-dir -r requirements.txt
# Health check
HEALTHCHECK --interval=30s CMD curl -f http://localhost/health
Volumes
# Mount host directory
docker run -v /host/path:/container/path myapp
# Named volume
docker run -v myvolume:/data myapp
# View volumes
docker volume ls
ELI10
Docker is like shipping containers for code:
- Package everything needed (dependencies, code, config)
- Send it anywhere (laptop, server, cloud)
- Runs the same everywhere!
No more "it works on my machine" problems!
Further Resources
Kubernetes
Overview
Kubernetes (K8s) orchestrates containerized applications at scale, handling deployment, scaling, and networking.
Core Concepts
Pods
Smallest deployable unit (usually one container):
apiVersion: v1
kind: Pod
metadata:
name: my-pod
spec:
containers:
- name: app
image: myapp:1.0
ports:
- containerPort: 8000
Deployments
Manages replicas of pods:
apiVersion: apps/v1
kind: Deployment
metadata:
name: myapp
spec:
replicas: 3
selector:
matchLabels:
app: myapp
template:
metadata:
labels:
app: myapp
spec:
containers:
- name: myapp
image: myapp:1.0
ports:
- containerPort: 8000
Services
Expose pods to network:
apiVersion: v1
kind: Service
metadata:
name: myapp-service
spec:
selector:
app: myapp
ports:
- protocol: TCP
port: 80
targetPort: 8000
type: LoadBalancer
kubectl Commands
# Create/update
kubectl apply -f deployment.yaml
# View resources
kubectl get pods
kubectl get deployments
kubectl get services
# Describe
kubectl describe pod my-pod
# Logs
kubectl logs my-pod
# Execute
kubectl exec -it my-pod -- bash
# Delete
kubectl delete pod my-pod
kubectl delete deployment myapp
# Scale
kubectl scale deployment myapp --replicas=5
# Port forwarding
kubectl port-forward myapp-pod 8000:8000
Architecture
┌─────────────────────────┐
│ Control Plane │
│ - API Server │
│ - etcd (store) │
│ - Scheduler │
│ - Controller Manager │
└─────────────────────────┘
↓
┌─────────────────────────────────────┐
│ Worker Nodes │
│ ┌──────┐ ┌──────┐ ┌──────┐ │
│ │ Pod │ │ Pod │ │ Pod │ │
│ └──────┘ └──────┘ └──────┘ │
└─────────────────────────────────────┘
ConfigMap & Secrets
# ConfigMap (non-sensitive)
apiVersion: v1
kind: ConfigMap
metadata:
name: app-config
data:
LOG_LEVEL: "info"
DATABASE_HOST: "db.example.com"
---
# Secret (sensitive)
apiVersion: v1
kind: Secret
metadata:
name: db-secret
type: Opaque
data:
password: cGFzc3dvcmQxMjM= # Base64 encoded
Namespaces
Logical cluster partitions:
kubectl create namespace development
kubectl apply -f deployment.yaml -n development
kubectl get pods -n development
Scaling & Updates
# Manual scaling
kubectl scale deployment myapp --replicas=10
# Rolling update
kubectl set image deployment/myapp myapp=myapp:2.0
kubectl rollout status deployment/myapp
kubectl rollout undo deployment/myapp # Revert
Resource Limits
spec:
containers:
- name: myapp
resources:
requests:
memory: "64Mi"
cpu: "250m"
limits:
memory: "128Mi"
cpu: "500m"
ELI10
Kubernetes is like a smart warehouse manager:
- Receives orders (deployments)
- Assigns workers (pods)
- Keeps right number working
- Fixes broken ones automatically
- Spreads load across workers
Imagine managing 1000 containers automatically!
Further Resources
CI/CD Fundamentals
Overview
CI (Continuous Integration): Automatically test code on every commit
CD (Continuous Deployment): Automatically deploy to production
Pipeline Stages
Code Commit
↓
Build (compile, package)
↓
Test (unit, integration, e2e)
↓
Deploy to Staging
↓
Manual/Automated Approval
↓
Deploy to Production
Tools
GitHub Actions
name: CI/CD
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Node.js
uses: actions/setup-node@v2
with:
node-version: '16'
- name: Install dependencies
run: npm install
- name: Run tests
run: npm test
- name: Run linter
run: npm run lint
- name: Deploy to production
if: github.ref == 'refs/heads/main'
run: npm run deploy
GitLab CI
stages:
- build
- test
- deploy
build:
stage: build
script:
- npm install
- npm run build
artifacts:
paths:
- dist/
test:
stage: test
script:
- npm install
- npm test
deploy:
stage: deploy
script:
- npm run deploy
only:
- main
Best Practices
- Automated Testing: Every commit
- Fast Feedback: Minutes, not hours
- Deploy Often: Small, frequent changes
- Monitoring: Alert on failures
- Rollback Ready: Revert quickly if needed
Pipeline as Code
Define pipeline in version control:
# .github/workflows/deploy.yml
name: Deploy
on: [push]
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- run: docker build -t myapp .
- run: docker push myapp:latest
- run: kubectl apply -f deployment.yaml
Deployment Strategies
Blue-Green
Blue (current): Production v1
Green (new): Production v2
Switch traffic instantly to v2
If issue: Switch back to v1
Canary
Release to 5% of users first
Monitor metrics
If healthy: 10% → 25% → 50% → 100%
If issues: Rollback at any stage
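The canary ramp above boils down to sending a small, adjustable fraction of traffic to the new version. A minimal sketch of that routing decision (the names `canary`/`stable` are illustrative):

```python
import random

def route_request(canary_fraction: float) -> str:
    """Send roughly canary_fraction of requests to the canary, the rest to stable."""
    return "canary" if random.random() < canary_fraction else "stable"
```

Ramping up means raising `canary_fraction` (0.05 → 0.10 → 0.25 → …) while watching metrics; rolling back means setting it to 0.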
Rolling
Stop pod, deploy new version
Repeat for each pod
Zero downtime
Monitoring in CI/CD
# Check metrics after deploy
- name: Health check
run: |
curl -f https://api.example.com/health || exit 1
- name: Performance check
run: |
response_time=$(curl -w '%{time_total}' https://api.example.com)
if (( $(echo "$response_time > 2.0" | bc -l) )); then
echo "Slow response: $response_time seconds"
exit 1
fi
Common Issues
Flaky Tests
Tests that pass or fail randomly.
Solution: Fix the test, increase timeouts, isolate dependencies
Deployment Failures
Solution: Pre-deployment checks, canary deployments, rollback procedures
Security Vulnerabilities
Solution: Dependency scanning, static code analysis, container scanning
ELI10
CI/CD is like an assembly line:
- CI: Test each part as made
- CD: Automatically package and ship
- Monitoring: Check if delivery was successful
Catch problems BEFORE customers see them!
Further Resources
System Design
Designing large-scale distributed systems for performance, scalability, and reliability.
Topics Covered
- Scalability: Horizontal vs vertical scaling, load balancing strategies
- Caching: Cache strategies, invalidation, distributed caching
- RPC: Remote Procedure Call frameworks and patterns
- Microservices: Microservices architecture patterns, service decomposition, communication patterns
- Databases: SQL vs NoSQL, sharding, replication
- Message Queues: Asynchronous processing, event-driven architecture
- Distributed Consensus: Consistency models, CAP theorem
- Design Patterns: Common solutions for distributed systems
Key Concepts
- Throughput: Requests per second
- Latency: Response time
- Availability: Uptime percentage
- Consistency: Data correctness
- Partition Tolerance: Handling failures
Design Goals
- Reliability: Surviving failures
- Scalability: Growing with demand
- Performance: Fast responses
- Maintainability: Easy to update
Steps to Design System
- Understand requirements and constraints
- High-level architecture
- Detailed design of components
- Identify bottlenecks
- Trade-offs and optimization
Navigation
Learn principles for designing systems at scale.
Scalability
Overview
Scalability is the ability to handle increased load by adding more resources.
Vertical Scaling
Add more power to existing machines:
1 machine: 8 cores, 32GB RAM
→ Upgrade to: 16 cores, 128GB RAM
Pros: Simple, less complexity
Cons: Hardware limits, single point of failure
Horizontal Scaling
Add more machines:
Machine 1: Handle requests
Machine 2: Handle requests
Machine 3: Handle requests
↓ Load Balancer ↓
Clients
Pros: Unlimited growth, fault tolerance
Cons: More complexity, state management
Load Balancing
Client Request
↓
┌─ Load Balancer ─┐
↓ ↓ ↓
Server1 Server2 Server3
Algorithms:
- Round Robin: Rotate servers
- Least Connections: Route to least busy
- IP Hash: Same client to same server
- Weighted: Distribute by capacity
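Two of the algorithms above can be sketched in a few lines. This is a simplified illustration, not a production balancer; the CRC32 hash stands in for whatever stable hash a real IP-hash balancer uses.

```python
import itertools
import zlib

class RoundRobinBalancer:
    """Round Robin: rotate through servers in order."""
    def __init__(self, servers):
        self._cycle = itertools.cycle(servers)

    def pick(self):
        return next(self._cycle)


def ip_hash_pick(servers, client_ip: str):
    """IP Hash: pin a client to the same server by hashing its IP.

    CRC32 is used here because it is stable across processes, unlike
    Python's built-in hash().
    """
    return servers[zlib.crc32(client_ip.encode()) % len(servers)]
```

IP Hash keeps a client on one server (useful for in-memory sessions); Round Robin spreads load evenly when requests are roughly the same cost.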
Database Scaling
Replication
Master-slave setup:
- Master: Writes
- Slaves: Read copies
Master (R/W)
↙ ↓ ↘
Slave1 Slave2 Slave3 (R only)
Sharding
Partition data across databases:
Shard 1: Users 1-1M
Shard 2: Users 1M-2M
Shard 3: Users 2M-3M
By User ID % 3 → Route to correct shard
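The `User ID % 3` routing above is a one-liner. Worth noting: changing the shard count remaps most users to different shards, which is why resharding with plain modulo is painful.

```python
def shard_for_user(user_id: int, num_shards: int = 3) -> int:
    """Route a user to a shard using the User ID % 3 scheme above."""
    return user_id % num_shards
```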
Caching
Store frequently accessed data:
Client Request
↓
Check Cache (fast)
↓ miss ↓ hit
Database → Client
(slow)
Cache Invalidation:
- TTL: Expire after time
- Event-based: Invalidate on update
- LRU: Remove least used items
Common Patterns
CDN (Content Delivery Network)
Distributed servers for static content:
User in Asia → Asia CDN Server (fast)
User in US → US CDN Server (fast)
Queue Systems
Handle spikes asynchronously:
Request → Queue → Worker Pool → Database
(fast) (slow processing)
Read Replicas
Separate read and write:
Write (slow): Direct to master
Read (fast): From replicas
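The read/write split above can be sketched as a router that inspects the query. A simplified illustration: real routers (e.g. in an ORM or proxy) also handle transactions and replication lag, which this ignores.

```python
import random

class ReplicatedRouter:
    """Send writes to the master, spread reads across replicas (sketch)."""
    def __init__(self, master, replicas):
        self.master = master
        self.replicas = replicas

    def pick(self, query: str):
        # Naive heuristic: SELECTs are reads, everything else is a write
        if query.lstrip().upper().startswith("SELECT"):
            return random.choice(self.replicas)
        return self.master
```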
Metrics
| Metric | Target |
|---|---|
| Response Time | <100ms |
| Throughput | >1000 req/s |
| Uptime | >99.9% |
| Availability | Five nines (99.999%) |
ELI10
Scalability is like growing a restaurant:
- Vertical: Make kitchen bigger (limited)
- Horizontal: Open more locations (unlimited)
- Load balancer: Customers split between locations
- Caching: Keep popular dishes ready
- Queues: Don't overwhelm kitchen
Design for growth from day one!
Further Resources
Caching Strategies
Overview
Caching stores frequently accessed data in fast memory to reduce latency and database load.
Cache Levels
L1: Browser cache (browser memory)
↓
L2: CDN cache (edge servers)
↓
L3: Application cache (Redis)
↓
L4: Database cache (MySQL buffer pool)
↓
Database (disk, slowest)
Caching Policies
Cache-Aside (Lazy Loading)
1. Check cache
2. If miss: Load from database
3. Store in cache
4. Return to client
def get_user(user_id):
# Check cache
cached = redis.get(f"user:{user_id}")
if cached:
return cached
# Load from DB
user = db.get_user(user_id)
# Store in cache
redis.set(f"user:{user_id}", user, ex=3600)
return user
Write-Through
Write to cache AND database simultaneously:
Update Request
↓
Cache ← updated
↓
Database ← updated
Ensures consistency but slower writes.
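A write-through update can be sketched with plain dicts standing in for the cache and the database (hypothetical stand-ins, for illustration only):

```python
def update_user_write_through(user_id, data, cache: dict, db: dict):
    """Write-through: update the database and the cache in the same operation."""
    db[user_id] = data                # source of truth first
    cache[f"user:{user_id}"] = data  # then keep the cache in sync
```

Because both stores are updated together, a subsequent read from the cache never sees data older than the database, at the cost of doing two writes on every update.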
Write-Behind (Write-Back)
Write to cache, asynchronously to database:
Update Request
↓
Cache ← updated (fast)
↓
Queue for DB
↓
Database ← updated (later)
Fast but risk of data loss.
Invalidation Strategies
TTL (Time-To-Live)
redis.set("key", value, ex=3600) # Expires in 1 hour
Pros: Simple
Cons: Stale data until expiry
Event-Based
Invalidate when data changes:
def update_user(user_id, data):
db.update_user(user_id, data)
redis.delete(f"user:{user_id}") # Invalidate
Pros: Fresh data
Cons: Complex logic
LRU (Least Recently Used)
Remove least used items when full:
[recent] A B C D E [old]
Remove E if memory full
Cache Eviction Policies
| Policy | Behavior |
|---|---|
| LRU | Remove least recently used |
| LFU | Remove least frequently used |
| FIFO | Remove oldest |
| Random | Remove random |
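The LRU policy from the table can be sketched with an `OrderedDict`, which keeps insertion order and lets us move entries to the "recent" end on access:

```python
from collections import OrderedDict

class LRUCache:
    """Minimal LRU eviction sketch backed by an OrderedDict."""
    def __init__(self, capacity: int):
        self.capacity = capacity
        self._data = OrderedDict()

    def get(self, key):
        if key not in self._data:
            return None
        self._data.move_to_end(key)  # mark as most recently used
        return self._data[key]

    def put(self, key, value):
        if key in self._data:
            self._data.move_to_end(key)
        self._data[key] = value
        if len(self._data) > self.capacity:
            self._data.popitem(last=False)  # evict least recently used
```

Production caches (Redis, memcached) implement approximations of this with far less bookkeeping per key, but the behavior is the same: the coldest entry goes first.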
Distributed Caching
Using Redis for distributed cache:
import redis
cache = redis.Redis(host='localhost', port=6379)
# Set
cache.set('key', 'value')
cache.setex('key', 3600, 'value') # With TTL
# Get
value = cache.get('key')
# Delete
cache.delete('key')
# Multi-key
cache.mget(['key1', 'key2', 'key3'])
Cache Stampede
Problem: Multiple requests load same expired key
3 requests arrive
Cache expired for key X
All 3 hit database (thundering herd)
Solution: Lock pattern
def get_cached(key):
    value = cache.get(key)
    if value:
        return value
    # SET NX acquires the lock atomically, so only one caller rebuilds the key
    # (a separate get-then-set check would let several callers through)
    if not cache.set(f"{key}:lock", "1", ex=5, nx=True):
        # Someone else is loading; wait for the cache to be repopulated
        return wait_for_cache(key)
    value = load_from_db(key)
    cache.set(key, value)
    cache.delete(f"{key}:lock")
    return value
Common Caching Patterns
Cache Coherence
Multiple caches have same data
Cache Penetration
Request for non-existent key hits DB repeatedly
Solution: Cache negative results
cache.set(f"user:{id}", "", ex=60)  # empty sentinel meaning "known not to exist"
Cache Avalanche
Many keys expire simultaneously
Solution: Randomize TTLs
ttl = 3600 + random.randint(0, 600)
cache.set(key, value, ex=ttl)
When NOT to Cache
- Constantly changing data
- Write-heavy data that is rarely read
- Small datasets
- Rare access patterns
ELI10
Cache is like keeping your favorite book on your desk:
- Fast access (don't go to library)
- Runs out of space (limited shelf)
- Need to replace old books (eviction)
- Book gets outdated (invalidation)
Trade memory for speed!
Further Resources
RPC (Remote Procedure Call)
RPC is a protocol that allows a program to execute a procedure on another computer as if it were a local procedure call.
Overview
RPC abstracts network communication, making distributed computing appear like local function calls.
Key Concepts:
- Client-Server model
- Stub generation
- Marshalling/Unmarshalling
- Synchronous or asynchronous calls
Common RPC Frameworks
| Framework | Protocol | Language |
|---|---|---|
| gRPC | HTTP/2, Protobuf | Multi-language |
| JSON-RPC | HTTP, JSON | Multi-language |
| XML-RPC | HTTP, XML | Multi-language |
| Apache Thrift | Binary | Multi-language |
gRPC Example
// service.proto
service Calculator {
rpc Add(Numbers) returns (Result);
}
message Numbers {
int32 a = 1;
int32 b = 2;
}
message Result {
int32 value = 1;
}
JSON-RPC Example
// Request
{
"jsonrpc": "2.0",
"method": "add",
"params": {"a": 5, "b": 3},
"id": 1
}
// Response
{
"jsonrpc": "2.0",
"result": 8,
"id": 1
}
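The request/response pair above maps onto a small dispatcher. A sketch of the server side, assuming a `methods` table of plain functions (`-32601` is the spec's "Method not found" error code):

```python
import json

def handle_jsonrpc(raw: str, methods: dict) -> str:
    """Dispatch a JSON-RPC 2.0 request string to a local function."""
    req = json.loads(raw)
    try:
        result = methods[req["method"]](**req["params"])
        resp = {"jsonrpc": "2.0", "result": result, "id": req["id"]}
    except KeyError:
        resp = {"jsonrpc": "2.0",
                "error": {"code": -32601, "message": "Method not found"},
                "id": req.get("id")}
    return json.dumps(resp)

# Example method table for the request shown above
methods = {"add": lambda a, b: a + b}
```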
Advantages
- Simple interface (like local calls)
- Language-agnostic
- Abstraction of network details
- Type safety (with IDL)
Challenges
- Network failures
- Latency
- Versioning
- Error handling complexity
RPC simplifies distributed system development by providing procedure call semantics over network communication.
Microservices Architecture
Microservices is an architectural style that structures an application as a collection of loosely coupled, independently deployable services. Each service is self-contained, implements a specific business capability, and communicates with other services through well-defined APIs.
Table of Contents
- Introduction
- Core Principles
- Service Design
- Communication Patterns
- Service Discovery
- API Gateway
- Data Management
- Deployment and DevOps
- Best Practices
- Challenges and Solutions
Introduction
What are Microservices? Microservices break down a large application into smaller, independent services that:
- Run in their own processes
- Communicate via lightweight protocols (HTTP, message queues)
- Can be deployed independently
- Can use different technologies
- Are organized around business capabilities
Benefits:
- Independent deployment and scaling
- Technology diversity
- Fault isolation
- Team autonomy
- Faster development cycles
- Easier to understand and maintain small services
Challenges:
- Distributed system complexity
- Network latency and failures
- Data consistency
- Testing complexity
- Operational overhead
- Service coordination
Core Principles
1. Single Responsibility
Each service handles one business capability.
❌ Monolith: One service handles users, orders, payments, inventory
✅ Microservices:
- User Service: Authentication, profiles
- Order Service: Order management
- Payment Service: Payment processing
- Inventory Service: Stock management
2. Decentralized Data Management
Each service owns its data store.
// Each service has its own database
User Service → Users DB (PostgreSQL)
Order Service → Orders DB (MongoDB)
Inventory Service → Inventory DB (MySQL)
3. Smart Endpoints, Dumb Pipes
Services are intelligent; communication is simple.
// Services handle business logic
// Communication uses simple protocols (HTTP, AMQP)
4. Design for Failure
Expect services to fail; build resilience.
// Circuit breakers
// Retries
// Fallbacks
// Timeouts
Service Design
Domain-Driven Design
// Bounded Contexts
Order Context {
- Order
- OrderItem
- OrderStatus
}
User Context {
- User
- Profile
- Authentication
}
Payment Context {
- Payment
- Transaction
- PaymentMethod
}
Service Size
// Small enough to:
// - Be maintained by a small team (2-pizza team)
// - Be rewritten in 2-4 weeks
// - Have a clear purpose
// Large enough to:
// - Provide business value
// - Minimize inter-service communication
// - Have a clear domain boundary
Example Service Structure
order-service/
├── src/
│ ├── api/
│ │ ├── routes/
│ │ └── controllers/
│ ├── domain/
│ │ ├── models/
│ │ └── services/
│ ├── infrastructure/
│ │ ├── database/
│ │ └── messaging/
│ ├── config/
│ └── main.ts
├── tests/
├── Dockerfile
├── package.json
└── README.md
Communication Patterns
Synchronous Communication (REST/HTTP)
Example: Order Service calling User Service
// order-service/userClient.js
const axios = require('axios');
class UserServiceClient {
constructor(baseURL) {
this.client = axios.create({
baseURL: baseURL || process.env.USER_SERVICE_URL,
timeout: 5000
});
}
async getUser(userId) {
try {
const response = await this.client.get(`/users/${userId}`);
return response.data;
} catch (error) {
if (error.code === 'ECONNABORTED') {
throw new Error('User service timeout');
}
throw error;
}
}
}
// Usage in order service
async function createOrder(orderData) {
const userClient = new UserServiceClient();
const user = await userClient.getUser(orderData.userId);
if (!user) {
throw new Error('User not found');
}
// Create order logic...
}
Asynchronous Communication (Message Queues)
Example: Event-Driven Communication
// order-service/publisher.js
const { Kafka } = require('kafkajs');
const kafka = new Kafka({
clientId: 'order-service',
brokers: ['kafka:9092']
});
const producer = kafka.producer();
async function publishOrderCreated(order) {
await producer.send({
topic: 'order.created',
messages: [{
key: `order:${order.id}`,
value: JSON.stringify({
orderId: order.id,
userId: order.userId,
items: order.items,
total: order.total,
timestamp: Date.now()
})
}]
});
}
// inventory-service/consumer.js
const consumer = kafka.consumer({
groupId: 'inventory-service'
});
async function start() {
await consumer.subscribe({ topic: 'order.created' });
await consumer.run({
eachMessage: async ({ message }) => {
const order = JSON.parse(message.value.toString());
console.log('Reserving inventory for order:', order.orderId);
await reserveInventory(order.items);
// Publish inventory.reserved event
await publishInventoryReserved(order.orderId);
}
});
}
API Composition Pattern
// api-gateway/orderComposer.js
class OrderComposer {
constructor(userService, orderService, inventoryService) {
this.userService = userService;
this.orderService = orderService;
this.inventoryService = inventoryService;
}
async getOrderDetails(orderId) {
    // Fetch the order first; its userId and items feed the next two calls
    const order = await this.orderService.getOrder(orderId);
    // Then fetch user and inventory in parallel
    const [user, inventory] = await Promise.all([
      this.userService.getUser(order.userId),
      this.inventoryService.checkAvailability(order.items)
    ]);
return {
order,
user: {
id: user.id,
name: user.name,
email: user.email
},
inventory
};
}
}
Service Discovery
Client-Side Discovery
// service-registry.js
class ServiceRegistry {
constructor() {
this.services = new Map();
}
register(serviceName, instance) {
if (!this.services.has(serviceName)) {
this.services.set(serviceName, []);
}
this.services.get(serviceName).push(instance);
}
discover(serviceName) {
const instances = this.services.get(serviceName) || [];
if (instances.length === 0) {
throw new Error(`No instances available for ${serviceName}`);
}
// Random pick; a production registry would track an index for true round-robin
return instances[Math.floor(Math.random() * instances.length)];
}
}
// Usage
const registry = new ServiceRegistry();
registry.register('user-service', { host: 'localhost', port: 3001 });
registry.register('user-service', { host: 'localhost', port: 3002 });
const instance = registry.discover('user-service');
Consul Integration
const Consul = require('consul');
const consul = new Consul({
host: 'consul-server',
port: 8500
});
// Register service
async function registerService() {
await consul.agent.service.register({
name: 'order-service',
id: `order-service-${process.env.INSTANCE_ID}`,
address: process.env.SERVICE_HOST,
port: parseInt(process.env.SERVICE_PORT),
check: {
http: `http://${process.env.SERVICE_HOST}:${process.env.SERVICE_PORT}/health`,
interval: '10s'
}
});
}
// Discover service
async function discoverService(serviceName) {
const result = await consul.health.service({
service: serviceName,
passing: true
});
const instances = result.map(item => ({
address: item.Service.Address,
port: item.Service.Port
}));
return instances;
}
API Gateway
Basic API Gateway
const express = require('express');
const proxy = require('express-http-proxy');
const app = express();
// Service URLs
const USER_SERVICE = process.env.USER_SERVICE_URL;
const ORDER_SERVICE = process.env.ORDER_SERVICE_URL;
const PRODUCT_SERVICE = process.env.PRODUCT_SERVICE_URL;
// Authentication middleware
app.use(async (req, res, next) => {
const token = req.headers.authorization;
if (!token) {
return res.status(401).json({ error: 'Unauthorized' });
}
try {
const user = await verifyToken(token);
req.user = user;
next();
} catch (error) {
res.status(401).json({ error: 'Invalid token' });
}
});
// Rate limiting
const rateLimit = require('express-rate-limit');
const limiter = rateLimit({
windowMs: 15 * 60 * 1000,
max: 100
});
app.use(limiter);
// Route to services
app.use('/api/users', proxy(USER_SERVICE));
app.use('/api/orders', proxy(ORDER_SERVICE));
app.use('/api/products', proxy(PRODUCT_SERVICE));
// Aggregation endpoint
app.get('/api/dashboard', async (req, res) => {
try {
const [user, orders, recommendations] = await Promise.all([
axios.get(`${USER_SERVICE}/users/${req.user.id}`),
axios.get(`${ORDER_SERVICE}/users/${req.user.id}/orders`),
axios.get(`${PRODUCT_SERVICE}/recommendations/${req.user.id}`)
]);
res.json({
user: user.data,
recentOrders: orders.data,
recommendations: recommendations.data
});
} catch (error) {
res.status(500).json({ error: 'Failed to load dashboard' });
}
});
app.listen(3000);
Data Management
Database Per Service
// Each service has its own database
services/
├── user-service/
│ └── database: PostgreSQL
├── order-service/
│ └── database: MongoDB
└── inventory-service/
└── database: MySQL
Saga Pattern (Distributed Transactions)
Choreography-Based Saga:
// Order Service
async function createOrder(orderData) {
const order = await Order.create({
...orderData,
status: 'PENDING'
});
// Publish event
await publishEvent('order.created', order);
return order;
}
// Inventory Service
consumer.on('order.created', async (order) => {
try {
await reserveInventory(order.items);
await publishEvent('inventory.reserved', { orderId: order.id });
} catch (error) {
await publishEvent('inventory.failed', {
orderId: order.id,
error: error.message
});
}
});
// Payment Service
consumer.on('inventory.reserved', async ({ orderId }) => {
try {
await processPayment(orderId);
await publishEvent('payment.completed', { orderId });
} catch (error) {
await publishEvent('payment.failed', { orderId, error: error.message });
}
});
// Order Service - Handle success/failure
consumer.on('payment.completed', async ({ orderId }) => {
await Order.update({ id: orderId }, { status: 'CONFIRMED' });
});
consumer.on('payment.failed', async ({ orderId }) => {
await Order.update({ id: orderId }, { status: 'CANCELLED' });
await publishEvent('order.cancelled', { orderId });
});
// Inventory Service - Compensating transaction
consumer.on('order.cancelled', async ({ orderId }) => {
await releaseInventory(orderId);
});
CQRS (Command Query Responsibility Segregation)
// Write Model (Commands)
class OrderWriteService {
async createOrder(command) {
const order = await Order.create(command);
// Publish event
await eventBus.publish('OrderCreated', {
orderId: order.id,
userId: order.userId,
items: order.items
});
return order.id;
}
}
// Read Model (Queries)
class OrderReadService {
constructor(readDatabase) {
this.db = readDatabase;
}
async getOrderById(orderId) {
return await this.db.orders.findOne({ id: orderId });
}
async getOrdersByUser(userId) {
return await this.db.orders.find({ userId });
}
}
// Event Handler (updates read model)
eventBus.on('OrderCreated', async (event) => {
await readDatabase.orders.insert({
id: event.orderId,
userId: event.userId,
items: event.items,
createdAt: new Date()
});
});
Deployment and DevOps
Docker Compose
version: '3.8'
services:
api-gateway:
build: ./api-gateway
ports:
- "3000:3000"
environment:
USER_SERVICE_URL: http://user-service:3001
ORDER_SERVICE_URL: http://order-service:3002
depends_on:
- user-service
- order-service
user-service:
build: ./user-service
environment:
DATABASE_URL: postgresql://postgres:password@user-db:5432/users
depends_on:
- user-db
order-service:
build: ./order-service
environment:
DATABASE_URL: mongodb://order-db:27017/orders
KAFKA_BROKERS: kafka:9092
depends_on:
- order-db
- kafka
user-db:
image: postgres:15
environment:
POSTGRES_PASSWORD: password
order-db:
image: mongo:6
kafka:
image: confluentinc/cp-kafka:latest
Kubernetes
# order-service-deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
name: order-service
spec:
replicas: 3
selector:
matchLabels:
app: order-service
template:
metadata:
labels:
app: order-service
spec:
containers:
- name: order-service
image: myregistry/order-service:1.0.0
ports:
- containerPort: 3000
env:
- name: DATABASE_URL
valueFrom:
secretKeyRef:
name: order-service-secrets
key: database-url
resources:
requests:
memory: "128Mi"
cpu: "100m"
limits:
memory: "512Mi"
cpu: "500m"
livenessProbe:
httpGet:
path: /health
port: 3000
initialDelaySeconds: 30
periodSeconds: 10
readinessProbe:
httpGet:
path: /ready
port: 3000
initialDelaySeconds: 5
periodSeconds: 5
---
apiVersion: v1
kind: Service
metadata:
name: order-service
spec:
selector:
app: order-service
ports:
- port: 80
targetPort: 3000
type: ClusterIP
Best Practices
1. Circuit Breaker Pattern
const axios = require('axios');
const CircuitBreaker = require('opossum');
const options = {
timeout: 3000,
errorThresholdPercentage: 50,
resetTimeout: 30000
};
const breaker = new CircuitBreaker(callExternalService, options);
breaker.fallback(() => ({ fallback: 'value' }));
breaker.on('open', () => console.log('Circuit opened'));
breaker.on('halfOpen', () => console.log('Circuit half-open'));
breaker.on('close', () => console.log('Circuit closed'));
async function callExternalService() {
const response = await axios.get('http://external-service/api');
return response.data;
}
// Usage (wrapped in an async IIFE so `await` is valid in CommonJS)
(async () => {
try {
const result = await breaker.fire();
console.log(result);
} catch (error) {
console.error('Service call failed');
}
})();
2. Health Checks
const express = require('express');
const app = express();
app.get('/health', async (req, res) => {
const health = {
uptime: process.uptime(),
message: 'OK',
timestamp: Date.now()
};
try {
// Check database connection
await database.ping();
health.database = 'connected';
} catch (error) {
health.database = 'disconnected';
health.message = 'Degraded';
return res.status(503).json(health);
}
res.json(health);
});
app.get('/ready', async (req, res) => {
try {
// Check if service is ready to accept traffic
await database.ping();
await cache.ping();
res.json({ status: 'ready' });
} catch (error) {
res.status(503).json({ status: 'not ready' });
}
});
3. Distributed Tracing
const { trace } = require('@opentelemetry/api');
const { NodeTracerProvider } = require('@opentelemetry/sdk-trace-node');
const provider = new NodeTracerProvider();
provider.register();
const tracer = trace.getTracer('order-service');
async function createOrder(orderData) {
const span = tracer.startSpan('createOrder');
try {
// Add attributes
span.setAttribute('user.id', orderData.userId);
span.setAttribute('order.total', orderData.total);
// Business logic
const order = await Order.create(orderData);
span.setStatus({ code: 1 }); // OK (SpanStatusCode.OK = 1; 0 is UNSET)
return order;
} catch (error) {
span.setStatus({
code: 2, // ERROR
message: error.message
});
throw error;
} finally {
span.end();
}
}
4. Logging
const winston = require('winston');
const logger = winston.createLogger({
level: 'info',
format: winston.format.json(),
defaultMeta: {
service: 'order-service',
version: '1.0.0'
},
transports: [
new winston.transports.File({ filename: 'error.log', level: 'error' }),
new winston.transports.File({ filename: 'combined.log' })
]
});
// Structured logging
logger.info('Order created', {
orderId: order.id,
userId: order.userId,
total: order.total,
timestamp: Date.now()
});
Challenges and Solutions
Challenge 1: Data Consistency
Solution: Use eventual consistency with event-driven architecture
// Use events to propagate data changes
await publishEvent('user.updated', { userId, email: newEmail });
// Other services listen and update their local views
consumer.on('user.updated', async (event) => {
await updateLocalUserCache(event.userId, event.email);
});
Challenge 2: Service Communication Failures
Solution: Implement retry logic with exponential backoff
const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms));
async function callServiceWithRetry(fn, maxRetries = 3) {
for (let i = 0; i < maxRetries; i++) {
try {
return await fn();
} catch (error) {
if (i === maxRetries - 1) throw error;
await sleep(Math.pow(2, i) * 1000);
}
}
}
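The backoff above waits exactly 2^i seconds, so clients that failed at the same moment all retry at the same moment. A common refinement is "full jitter": wait a random duration up to the exponential cap so concurrent retries spread out. A sketch (function name hypothetical):

```javascript
const sleep = (ms) => new Promise((resolve) => setTimeout(resolve, ms));

// Exponential backoff with full jitter: the wait is uniformly random between
// 0 and the exponential cap (baseMs * 2^attempt), which avoids a retry
// stampede when many callers fail simultaneously.
async function callServiceWithJitter(fn, maxRetries = 3, baseMs = 1000) {
  for (let i = 0; i < maxRetries; i++) {
    try {
      return await fn();
    } catch (error) {
      if (i === maxRetries - 1) throw error;
      const capMs = Math.pow(2, i) * baseMs;
      await sleep(Math.random() * capMs);
    }
  }
}
```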
Challenge 3: Testing
Solution: Use contract testing and integration tests
// Contract test (using Pact)
const { Pact } = require('@pact-foundation/pact');
const provider = new Pact({
consumer: 'order-service',
provider: 'user-service'
});
describe('User Service Contract', () => {
it('should get user by ID', async () => {
await provider.addInteraction({
state: 'user 123 exists',
uponReceiving: 'a request for user 123',
withRequest: {
method: 'GET',
path: '/users/123'
},
willRespondWith: {
status: 200,
body: { id: 123, name: 'John' }
}
});
// Test your client code
const user = await userClient.getUser(123);
expect(user.id).toBe(123);
});
});
Resources
Books:
- Building Microservices by Sam Newman
- Microservices Patterns by Chris Richardson
- Release It! by Michael Nygard
Tools:
- Kubernetes
- Istio - Service Mesh
- Consul - Service Discovery
- Jaeger - Distributed Tracing
Mobile Development
Cross-platform and native mobile application development frameworks and best practices.
Topics Covered
Cross-Platform Frameworks
- React Native: Build native mobile apps using React and JavaScript/TypeScript
- Component architecture and navigation
- Platform-specific code
- Native modules and bridges
- Performance optimization
- State management
- Testing and debugging
- Flutter: Google's UI toolkit for building natively compiled applications
- Dart programming language
- Widget-based architecture
- State management (Provider, Riverpod, Bloc)
- Platform channels for native code
- Material Design and Cupertino widgets
- Hot reload and development workflow
Platform Comparison
| Feature | React Native | Flutter |
|---|---|---|
| Language | JavaScript/TypeScript | Dart |
| Performance | Near-native | Native |
| UI | Native components | Custom rendering |
| Community | Very large | Growing rapidly |
| Learning Curve | Easier (if you know React) | Moderate |
| Hot Reload | Yes | Yes |
| Code Sharing | High (with web) | High |
Development Workflow
- Setup: Install development environment and tools
- Design: Create UI/UX mockups
- Development: Write code with hot reload
- Testing: Unit, integration, and E2E tests
- Debugging: Use developer tools
- Deployment: Build and publish to app stores
Key Concepts
- Cross-platform: Write once, run on iOS and Android
- Native modules: Access platform-specific features
- State management: Handle app state efficiently
- Navigation: Implement screen transitions
- Performance: Optimize for mobile devices
- Platform differences: Handle iOS and Android specifics
Mobile App Architecture
- MVVM (Model-View-ViewModel): Separate UI from business logic
- Clean Architecture: Layered approach with dependency inversion
- BLoC (Business Logic Component): Event-driven architecture
- Redux/MobX: Centralized state management
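The MVVM split above can be sketched framework-agnostically: the ViewModel owns UI state and exposes commands, while the View only subscribes and renders. A minimal illustration (class and method names are hypothetical):

```javascript
// Minimal MVVM sketch: the ViewModel holds the model state and notifies
// subscribers on change, so the View layer (a React Native component, say)
// only renders the latest value and forwards user events as commands.
class CounterViewModel {
  constructor() {
    this.count = 0;        // model state
    this.listeners = [];   // view-layer subscribers
  }

  subscribe(listener) {
    this.listeners.push(listener);
  }

  // Command invoked by the View (e.g. from a button's onPress handler).
  increment() {
    this.count += 1;
    this.listeners.forEach((notify) => notify(this.count));
  }
}

// Usage: the "View" is just a subscriber here.
const vm = new CounterViewModel();
vm.subscribe((count) => console.log(`render: ${count}`));
vm.increment();
```

Because the ViewModel has no UI imports, its business logic can be unit-tested without rendering any components.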
Best Practices
- Performance
- Optimize images and assets
- Use lazy loading
- Minimize re-renders
- Profile and monitor performance
- User Experience
- Follow platform guidelines (iOS HIG, Material Design)
- Handle offline mode gracefully
- Provide feedback for actions
- Optimize for different screen sizes
- Security
- Secure storage for sensitive data
- API authentication and authorization
- SSL pinning
- Code obfuscation
- Testing
- Write unit tests for business logic
- Integration tests for features
- E2E tests for critical flows
- Test on multiple devices
Explore each framework to build production-ready mobile applications for iOS and Android.
React Native
React Native is a popular framework for building native mobile applications using React and JavaScript/TypeScript. It lets developers combine React with native platform capabilities to build iOS and Android apps from a single shared codebase.
Table of Contents
- Introduction
- Setup and Installation
- Core Components
- Styling
- Navigation
- State Management
- API and Data Fetching
- Native Modules
- Performance Optimization
- Testing
- Deployment
Introduction
Key Features:
- Cross-platform development (iOS and Android)
- Native performance
- Hot reloading for fast development
- Large ecosystem and community
- Code reusability with web React
- Native API access
- Over-the-air (OTA) updates
Use Cases:
- Cross-platform mobile apps
- MVP development
- Apps requiring frequent updates
- Teams with JavaScript/React expertise
- Apps with shared business logic
Setup and Installation
Prerequisites
# Install Node.js (14+)
node --version
npm --version
# Install Watchman (macOS)
brew install watchman
# Install Xcode (macOS, for iOS development)
# Install Android Studio (for Android development)
Create New Project
# Using React Native CLI
npx react-native init MyApp
cd MyApp
# Using Expo (recommended for beginners)
npx create-expo-app MyApp
cd MyApp
npx expo start
Running the App
# React Native CLI
# iOS
npx react-native run-ios
# Android
npx react-native run-android
# Expo
npx expo start
# Then press 'i' for iOS or 'a' for Android
Core Components
View and Text
import React from 'react';
import { View, Text, StyleSheet } from 'react-native';
export default function App() {
return (
<View style={styles.container}>
<Text style={styles.title}>Hello React Native!</Text>
<Text style={styles.subtitle}>Welcome to mobile development</Text>
</View>
);
}
const styles = StyleSheet.create({
container: {
flex: 1,
justifyContent: 'center',
alignItems: 'center',
backgroundColor: '#fff',
},
title: {
fontSize: 24,
fontWeight: 'bold',
marginBottom: 10,
},
subtitle: {
fontSize: 16,
color: '#666',
},
});
Button and TouchableOpacity
import { Button, TouchableOpacity, Alert } from 'react-native';
function MyComponent() {
const handlePress = () => {
Alert.alert('Button Pressed', 'You clicked the button!');
};
return (
<View>
{/* Basic Button */}
<Button title="Click Me" onPress={handlePress} color="#007AFF" />
{/* Custom Touchable */}
<TouchableOpacity
style={styles.customButton}
onPress={handlePress}
activeOpacity={0.7}
>
<Text style={styles.buttonText}>Custom Button</Text>
</TouchableOpacity>
</View>
);
}
const styles = StyleSheet.create({
customButton: {
backgroundColor: '#007AFF',
padding: 15,
borderRadius: 8,
alignItems: 'center',
marginTop: 10,
},
buttonText: {
color: 'white',
fontSize: 16,
fontWeight: 'bold',
},
});
TextInput
import { useState } from 'react';
import { TextInput } from 'react-native';
function LoginForm() {
const [email, setEmail] = useState('');
const [password, setPassword] = useState('');
return (
<View style={styles.form}>
<TextInput
style={styles.input}
placeholder="Email"
value={email}
onChangeText={setEmail}
keyboardType="email-address"
autoCapitalize="none"
autoCorrect={false}
/>
<TextInput
style={styles.input}
placeholder="Password"
value={password}
onChangeText={setPassword}
secureTextEntry
autoCapitalize="none"
/>
<Button
title="Login"
onPress={() => handleLogin(email, password)}
/>
</View>
);
}
ScrollView and FlatList
import { ScrollView, FlatList } from 'react-native';
// ScrollView - for small lists
function SimpleList() {
return (
<ScrollView>
{items.map((item) => (
<View key={item.id} style={styles.item}>
<Text>{item.name}</Text>
</View>
))}
</ScrollView>
);
}
// FlatList - for large lists (better performance)
function OptimizedList() {
const DATA = [
{ id: '1', title: 'Item 1' },
{ id: '2', title: 'Item 2' },
{ id: '3', title: 'Item 3' },
];
const renderItem = ({ item }) => (
<View style={styles.item}>
<Text style={styles.title}>{item.title}</Text>
</View>
);
return (
<FlatList
data={DATA}
renderItem={renderItem}
keyExtractor={(item) => item.id}
onRefresh={() => refreshData()}
refreshing={loading}
/>
);
}
Image
import { Image } from 'react-native';
function ImageExample() {
return (
<View>
{/* Local Image */}
<Image
source={require('./assets/logo.png')}
style={{ width: 100, height: 100 }}
resizeMode="contain"
/>
{/* Remote Image */}
<Image
source={{ uri: 'https://example.com/image.jpg' }}
style={{ width: 200, height: 200 }}
resizeMode="cover"
/>
</View>
);
}
Styling
StyleSheet
import { StyleSheet } from 'react-native';
const styles = StyleSheet.create({
container: {
flex: 1,
padding: 20,
backgroundColor: '#f5f5f5',
},
card: {
backgroundColor: 'white',
borderRadius: 8,
padding: 15,
marginBottom: 10,
shadowColor: '#000',
shadowOffset: { width: 0, height: 2 },
shadowOpacity: 0.1,
shadowRadius: 4,
elevation: 3, // Android shadow
},
title: {
fontSize: 18,
fontWeight: 'bold',
marginBottom: 5,
},
});
Flexbox Layout
// Flex Direction
<View style={{ flex: 1, flexDirection: 'row' }}>
<View style={{ flex: 1, backgroundColor: 'red' }} />
<View style={{ flex: 2, backgroundColor: 'blue' }} />
</View>
// Justify Content
<View style={{ flex: 1, justifyContent: 'space-between' }}>
<Text>Top</Text>
<Text>Middle</Text>
<Text>Bottom</Text>
</View>
// Align Items
<View style={{ flex: 1, alignItems: 'center' }}>
<Text>Centered Horizontally</Text>
</View>
Responsive Design
import { Dimensions, Platform } from 'react-native';
const { width, height } = Dimensions.get('window');
const styles = StyleSheet.create({
container: {
width: width * 0.9, // 90% of screen width
padding: width < 350 ? 10 : 20, // Conditional padding
},
image: {
width: width - 40,
height: (width - 40) * 0.6, // Aspect ratio
},
platformSpecific: {
...Platform.select({
ios: {
shadowColor: '#000',
shadowOffset: { width: 0, height: 2 },
shadowOpacity: 0.3,
shadowRadius: 4,
},
android: {
elevation: 5,
},
}),
},
});
Navigation
React Navigation
npm install @react-navigation/native
npm install react-native-screens react-native-safe-area-context
npm install @react-navigation/stack
Stack Navigator:
import { NavigationContainer } from '@react-navigation/native';
import { createStackNavigator } from '@react-navigation/stack';
const Stack = createStackNavigator();
function HomeScreen({ navigation }) {
return (
<View style={styles.container}>
<Text>Home Screen</Text>
<Button
title="Go to Details"
onPress={() => navigation.navigate('Details', { itemId: 42 })}
/>
</View>
);
}
function DetailsScreen({ route, navigation }) {
const { itemId } = route.params;
return (
<View style={styles.container}>
<Text>Details Screen</Text>
<Text>Item ID: {itemId}</Text>
<Button title="Go Back" onPress={() => navigation.goBack()} />
</View>
);
}
export default function App() {
return (
<NavigationContainer>
<Stack.Navigator
initialRouteName="Home"
screenOptions={{
headerStyle: { backgroundColor: '#007AFF' },
headerTintColor: '#fff',
headerTitleStyle: { fontWeight: 'bold' },
}}
>
<Stack.Screen
name="Home"
component={HomeScreen}
options={{ title: 'Welcome' }}
/>
<Stack.Screen name="Details" component={DetailsScreen} />
</Stack.Navigator>
</NavigationContainer>
);
}
Tab Navigator:
import { createBottomTabNavigator } from '@react-navigation/bottom-tabs';
import Ionicons from 'react-native-vector-icons/Ionicons';
const Tab = createBottomTabNavigator();
export default function App() {
return (
<NavigationContainer>
<Tab.Navigator
screenOptions={({ route }) => ({
tabBarIcon: ({ focused, color, size }) => {
let iconName;
if (route.name === 'Home') {
iconName = focused ? 'home' : 'home-outline';
} else if (route.name === 'Settings') {
iconName = focused ? 'settings' : 'settings-outline';
}
return <Ionicons name={iconName} size={size} color={color} />;
},
tabBarActiveTintColor: '#007AFF',
tabBarInactiveTintColor: 'gray',
})}
>
<Tab.Screen name="Home" component={HomeScreen} />
<Tab.Screen name="Settings" component={SettingsScreen} />
</Tab.Navigator>
</NavigationContainer>
);
}
State Management
Context API
import React, { createContext, useContext, useState } from 'react';
const AuthContext = createContext();
export function AuthProvider({ children }) {
const [user, setUser] = useState(null);
const login = async (email, password) => {
// API call
const response = await fetch('/api/login', {
method: 'POST',
body: JSON.stringify({ email, password }),
});
const data = await response.json();
setUser(data.user);
};
const logout = () => {
setUser(null);
};
return (
<AuthContext.Provider value={{ user, login, logout }}>
{children}
</AuthContext.Provider>
);
}
export function useAuth() {
return useContext(AuthContext);
}
// Usage
function ProfileScreen() {
const { user, logout } = useAuth();
return (
<View>
<Text>Welcome, {user?.name}</Text>
<Button title="Logout" onPress={logout} />
</View>
);
}
Redux Toolkit
npm install @reduxjs/toolkit react-redux
import { createSlice, configureStore } from '@reduxjs/toolkit';
import { Provider, useSelector, useDispatch } from 'react-redux';
// Slice
const counterSlice = createSlice({
name: 'counter',
initialState: { value: 0 },
reducers: {
increment: (state) => {
state.value += 1;
},
decrement: (state) => {
state.value -= 1;
},
},
});
export const { increment, decrement } = counterSlice.actions;
// Store
const store = configureStore({
reducer: {
counter: counterSlice.reducer,
},
});
// Component
function Counter() {
const count = useSelector((state) => state.counter.value);
const dispatch = useDispatch();
return (
<View>
<Text>{count}</Text>
<Button title="+" onPress={() => dispatch(increment())} />
<Button title="-" onPress={() => dispatch(decrement())} />
</View>
);
}
// App
export default function App() {
return (
<Provider store={store}>
<Counter />
</Provider>
);
}
API and Data Fetching
Fetch API
import { useState, useEffect } from 'react';
function UserProfile({ userId }) {
const [user, setUser] = useState(null);
const [loading, setLoading] = useState(true);
const [error, setError] = useState(null);
useEffect(() => {
fetchUser();
}, [userId]);
const fetchUser = async () => {
try {
setLoading(true);
const response = await fetch(`https://api.example.com/users/${userId}`);
if (!response.ok) {
throw new Error('Failed to fetch user');
}
const data = await response.json();
setUser(data);
} catch (err) {
setError(err.message);
} finally {
setLoading(false);
}
};
if (loading) return <Text>Loading...</Text>;
if (error) return <Text>Error: {error}</Text>;
return (
<View>
<Text>{user.name}</Text>
<Text>{user.email}</Text>
</View>
);
}
Axios
npm install axios
import axios from 'axios';
const api = axios.create({
baseURL: 'https://api.example.com',
timeout: 10000,
headers: {
'Content-Type': 'application/json',
},
});
// Interceptors
api.interceptors.request.use(
(config) => {
const token = getToken();
if (token) {
config.headers.Authorization = `Bearer ${token}`;
}
return config;
},
(error) => Promise.reject(error)
);
api.interceptors.response.use(
(response) => response,
(error) => {
if (error.response?.status === 401) {
// Handle unauthorized
logout();
}
return Promise.reject(error);
}
);
// Usage
const fetchUsers = async () => {
const response = await api.get('/users');
return response.data;
};
const createUser = async (userData) => {
const response = await api.post('/users', userData);
return response.data;
};
Native Modules
Accessing Device Features
// Camera
import { Camera } from 'expo-camera';
function CameraScreen() {
const [hasPermission, setHasPermission] = useState(null);
useEffect(() => {
(async () => {
const { status } = await Camera.requestCameraPermissionsAsync();
setHasPermission(status === 'granted');
})();
}, []);
if (hasPermission === null) {
return <View />;
}
return (
<Camera style={{ flex: 1 }} type={Camera.Constants.Type.back}>
{/* Camera UI */}
</Camera>
);
}
// Location
import * as Location from 'expo-location';
const getLocation = async () => {
const { status } = await Location.requestForegroundPermissionsAsync();
if (status !== 'granted') {
return;
}
const location = await Location.getCurrentPositionAsync({});
console.log(location.coords);
};
// Notifications
import * as Notifications from 'expo-notifications';
const sendNotification = async () => {
await Notifications.scheduleNotificationAsync({
content: {
title: "You've got mail!",
body: 'Here is the notification body',
},
trigger: { seconds: 2 },
});
};
Performance Optimization
Memoization
import React, { useMemo, useCallback } from 'react';
function ExpensiveComponent({ data }) {
// Memoize expensive calculations
const processedData = useMemo(() => {
return data.map((item) => {
// Expensive operation
return processItem(item);
});
}, [data]);
// Memoize callbacks
const handlePress = useCallback(() => {
console.log('Button pressed');
}, []);
return (
<View>
{processedData.map((item) => (
<Item key={item.id} data={item} onPress={handlePress} />
))}
</View>
);
}
// Memo component
const Item = React.memo(({ data, onPress }) => {
return (
<TouchableOpacity onPress={onPress}>
<Text>{data.name}</Text>
</TouchableOpacity>
);
});
FlatList Optimization
<FlatList
data={data}
renderItem={renderItem}
keyExtractor={(item) => item.id}
// Performance optimizations
removeClippedSubviews={true}
maxToRenderPerBatch={10}
updateCellsBatchingPeriod={50}
initialNumToRender={10}
windowSize={10}
getItemLayout={(data, index) => ({
length: ITEM_HEIGHT,
offset: ITEM_HEIGHT * index,
index,
})}
/>
Testing
Jest and React Native Testing Library
npm install --save-dev @testing-library/react-native
import { render, fireEvent } from '@testing-library/react-native';
import Counter from './Counter';
describe('Counter', () => {
it('renders correctly', () => {
const { getByText } = render(<Counter />);
expect(getByText('Count: 0')).toBeTruthy();
});
it('increments counter', () => {
const { getByText, getByTestId } = render(<Counter />);
const button = getByTestId('increment-button');
fireEvent.press(button);
expect(getByText('Count: 1')).toBeTruthy();
});
});
Deployment
iOS
# Build for release
npx react-native run-ios --configuration Release
# Or with Xcode
# Open ios/YourApp.xcworkspace
# Select Generic iOS Device
# Product > Archive
# Upload to App Store
Android
# Generate release APK
cd android
./gradlew assembleRelease
# APK location:
# android/app/build/outputs/apk/release/app-release.apk
# Generate AAB (App Bundle)
./gradlew bundleRelease
Flutter
Flutter is Google's UI toolkit for building natively compiled applications for mobile, web, and desktop from a single codebase. It uses the Dart programming language and provides a rich set of pre-designed widgets for creating beautiful, high-performance applications.
Table of Contents
- Introduction
- Setup and Installation
- Dart Basics
- Widgets
- Layouts
- State Management
- Navigation and Routing
- Networking
- Local Storage
- Testing
- Deployment
Introduction
Key Features:
- Single codebase for iOS, Android, web, and desktop
- Fast development with hot reload
- Beautiful, customizable widgets
- Native performance
- Rich animation support
- Strong typing with Dart
- Extensive package ecosystem
Use Cases:
- Cross-platform mobile apps
- Material Design and Cupertino (iOS-style) apps
- High-performance UIs
- Apps with complex animations
- MVPs and startups
- Enterprise applications
Setup and Installation
Install Flutter
macOS:
# Download Flutter SDK
# https://flutter.dev/docs/get-started/install/macos
# Add to PATH
export PATH="$PATH:`pwd`/flutter/bin"
# Run doctor
flutter doctor
# Install Xcode
# Install Android Studio
Windows:
# Download Flutter SDK
# https://flutter.dev/docs/get-started/install/windows
# Add to PATH
# Run flutter doctor
Create New Project
# Create project
flutter create my_app
cd my_app
# Run on iOS
flutter run -d ios
# Run on Android
flutter run -d android
# Run on web
flutter run -d chrome
Project Structure
my_app/
├── android/ # Android-specific code
├── ios/ # iOS-specific code
├── lib/ # Dart source code
│ ├── main.dart # Entry point
│ ├── screens/ # Screen widgets
│ ├── widgets/ # Reusable widgets
│ ├── models/ # Data models
│ ├── services/ # API services
│ └── utils/ # Utilities
├── test/ # Tests
├── pubspec.yaml # Dependencies
└── README.md
Dart Basics
Variables and Types
// Variables
var name = 'John';
String city = 'New York';
int age = 30;
double height = 5.9;
bool isActive = true;
// Final and const
final String country = 'USA'; // Runtime constant
const double pi = 3.14159; // Compile-time constant
// Null safety
String? nullableName; // Can be null
String nonNullName = 'John'; // Cannot be null
// Late initialization
late String description;
Functions
// Basic function
String greet(String name) {
return 'Hello, $name!';
}
// Arrow function
String greet(String name) => 'Hello, $name!';
// Optional parameters
String greet(String name, [String? title]) {
return title != null ? 'Hello, $title $name!' : 'Hello, $name!';
}
// Named parameters
String greet({required String name, String title = 'Mr.'}) {
return 'Hello, $title $name!';
}
// Async function
Future<String> fetchData() async {
await Future.delayed(Duration(seconds: 2));
return 'Data loaded';
}
Classes
class User {
String name;
int age;
// Constructor
User(this.name, this.age);
// Named constructor
User.guest() : name = 'Guest', age = 0;
// Method
String introduce() {
return 'I am $name, $age years old';
}
// Getter
bool get isAdult => age >= 18;
// Setter
set updateAge(int newAge) {
if (newAge > 0) age = newAge;
}
}
// Usage
var user = User('John', 30);
print(user.introduce());
print(user.isAdult);
Lists and Maps
// Lists
List<String> names = ['John', 'Jane', 'Bob'];
names.add('Alice');
names.remove('Bob');
// Maps
Map<String, int> ages = {
'John': 30,
'Jane': 28,
};
ages['Bob'] = 35;
// Iteration
names.forEach((name) => print(name));
ages.forEach((key, value) => print('$key: $value'));
Widgets
Stateless Widget
import 'package:flutter/material.dart';
class WelcomeScreen extends StatelessWidget {
final String title;
const WelcomeScreen({Key? key, required this.title}) : super(key: key);
@override
Widget build(BuildContext context) {
return Scaffold(
appBar: AppBar(
title: Text(title),
),
body: Center(
child: Text(
'Welcome to Flutter!',
style: TextStyle(fontSize: 24),
),
),
);
}
}
Stateful Widget
class CounterScreen extends StatefulWidget {
@override
_CounterScreenState createState() => _CounterScreenState();
}
class _CounterScreenState extends State<CounterScreen> {
int _counter = 0;
void _incrementCounter() {
setState(() {
_counter++;
});
}
@override
Widget build(BuildContext context) {
return Scaffold(
appBar: AppBar(
title: Text('Counter'),
),
body: Center(
child: Column(
mainAxisAlignment: MainAxisAlignment.center,
children: [
Text('Count:'),
Text(
'$_counter',
style: TextStyle(fontSize: 48, fontWeight: FontWeight.bold),
),
],
),
),
floatingActionButton: FloatingActionButton(
onPressed: _incrementCounter,
child: Icon(Icons.add),
),
);
}
}
Common Widgets
// Text
Text(
'Hello Flutter',
style: TextStyle(
fontSize: 24,
fontWeight: FontWeight.bold,
color: Colors.blue,
),
)
// Image
Image.network('https://example.com/image.jpg')
Image.asset('assets/logo.png')
// Button
ElevatedButton(
onPressed: () {
print('Button pressed');
},
child: Text('Click Me'),
)
// TextField
TextField(
decoration: InputDecoration(
labelText: 'Email',
hintText: 'Enter your email',
border: OutlineInputBorder(),
),
onChanged: (value) {
print(value);
},
)
// Container
Container(
width: 200,
height: 100,
padding: EdgeInsets.all(16),
margin: EdgeInsets.all(8),
decoration: BoxDecoration(
color: Colors.blue,
borderRadius: BorderRadius.circular(12),
boxShadow: [
BoxShadow(
color: Colors.grey.withOpacity(0.5),
spreadRadius: 2,
blurRadius: 5,
offset: Offset(0, 3),
),
],
),
child: Text('Styled Container'),
)
Layouts
Column and Row
Column(
mainAxisAlignment: MainAxisAlignment.center,
crossAxisAlignment: CrossAxisAlignment.start,
children: [
Text('First'),
Text('Second'),
Text('Third'),
],
)
Row(
mainAxisAlignment: MainAxisAlignment.spaceEvenly,
children: [
Icon(Icons.star),
Icon(Icons.favorite),
Icon(Icons.thumb_up),
],
)
Stack
Stack(
children: [
Container(
width: 200,
height: 200,
color: Colors.blue,
),
Positioned(
top: 20,
left: 20,
child: Text('Overlayed Text'),
),
],
)
ListView
// Simple ListView
ListView(
children: [
ListTile(
leading: Icon(Icons.person),
title: Text('John Doe'),
subtitle: Text('john@example.com'),
trailing: Icon(Icons.arrow_forward),
onTap: () {
print('Tapped');
},
),
ListTile(
leading: Icon(Icons.person),
title: Text('Jane Smith'),
),
],
)
// ListView.builder (for large lists)
ListView.builder(
itemCount: items.length,
itemBuilder: (context, index) {
return ListTile(
title: Text(items[index].name),
);
},
)
// ListView.separated
ListView.separated(
itemCount: items.length,
itemBuilder: (context, index) => ListTile(
title: Text(items[index]),
),
separatorBuilder: (context, index) => Divider(),
)
GridView
GridView.count(
crossAxisCount: 2,
children: List.generate(20, (index) {
return Card(
child: Center(
child: Text('Item $index'),
),
);
}),
)
GridView.builder(
gridDelegate: SliverGridDelegateWithFixedCrossAxisCount(
crossAxisCount: 3,
crossAxisSpacing: 10,
mainAxisSpacing: 10,
),
itemCount: items.length,
itemBuilder: (context, index) {
return Card(
child: Image.network(items[index].imageUrl),
);
},
)
State Management
Provider
# pubspec.yaml
dependencies:
provider: ^6.0.0
import 'package:provider/provider.dart';
// Model
class Counter with ChangeNotifier {
int _count = 0;
int get count => _count;
void increment() {
_count++;
notifyListeners();
}
void decrement() {
_count--;
notifyListeners();
}
}
// Main app
void main() {
runApp(
ChangeNotifierProvider(
create: (context) => Counter(),
child: MyApp(),
),
);
}
// Consumer widget
class CounterScreen extends StatelessWidget {
@override
Widget build(BuildContext context) {
return Scaffold(
appBar: AppBar(title: Text('Counter')),
body: Center(
child: Consumer<Counter>(
builder: (context, counter, child) {
return Text(
'${counter.count}',
style: TextStyle(fontSize: 48),
);
},
),
),
floatingActionButton: FloatingActionButton(
onPressed: () {
context.read<Counter>().increment();
},
child: Icon(Icons.add),
),
);
}
}
Riverpod
dependencies:
flutter_riverpod: ^2.0.0
import 'package:flutter_riverpod/flutter_riverpod.dart';
// Provider
final counterProvider = StateProvider<int>((ref) => 0);
// Main app
void main() {
runApp(
ProviderScope(
child: MyApp(),
),
);
}
// Consumer widget
class CounterScreen extends ConsumerWidget {
@override
Widget build(BuildContext context, WidgetRef ref) {
final counter = ref.watch(counterProvider);
return Scaffold(
body: Center(
child: Text('$counter'),
),
floatingActionButton: FloatingActionButton(
onPressed: () {
ref.read(counterProvider.notifier).state++;
},
child: Icon(Icons.add),
),
);
}
}
Navigation and Routing
Basic Navigation
// Navigate to new screen
Navigator.push(
context,
MaterialPageRoute(builder: (context) => SecondScreen()),
);
// Navigate back
Navigator.pop(context);
// Navigate with data
Navigator.push(
context,
MaterialPageRoute(
builder: (context) => DetailScreen(id: 123),
),
);
// Return data
final result = await Navigator.push(
context,
MaterialPageRoute(builder: (context) => SecondScreen()),
);
Named Routes
// Define routes
MaterialApp(
initialRoute: '/',
routes: {
'/': (context) => HomeScreen(),
'/details': (context) => DetailsScreen(),
'/profile': (context) => ProfileScreen(),
},
)
// Navigate
Navigator.pushNamed(context, '/details');
// With arguments
Navigator.pushNamed(
context,
'/details',
arguments: {'id': 123},
);
// Extract arguments
class DetailsScreen extends StatelessWidget {
@override
Widget build(BuildContext context) {
final args = ModalRoute.of(context)!.settings.arguments as Map;
final id = args['id'];
return Scaffold(
appBar: AppBar(title: Text('Details $id')),
);
}
}
Networking
HTTP Package
dependencies:
http: ^0.13.0
import 'package:http/http.dart' as http;
import 'dart:convert';
// GET request
Future<List<User>> fetchUsers() async {
final response = await http.get(
Uri.parse('https://api.example.com/users'),
);
if (response.statusCode == 200) {
List<dynamic> data = jsonDecode(response.body);
return data.map((json) => User.fromJson(json)).toList();
} else {
throw Exception('Failed to load users');
}
}
// POST request
Future<User> createUser(String name, String email) async {
final response = await http.post(
Uri.parse('https://api.example.com/users'),
headers: {'Content-Type': 'application/json'},
body: jsonEncode({
'name': name,
'email': email,
}),
);
if (response.statusCode == 201) {
return User.fromJson(jsonDecode(response.body));
} else {
throw Exception('Failed to create user');
}
}
// FutureBuilder
class UsersList extends StatelessWidget {
@override
Widget build(BuildContext context) {
return FutureBuilder<List<User>>(
future: fetchUsers(),
builder: (context, snapshot) {
if (snapshot.connectionState == ConnectionState.waiting) {
return CircularProgressIndicator();
} else if (snapshot.hasError) {
return Text('Error: ${snapshot.error}');
} else if (snapshot.hasData) {
return ListView.builder(
itemCount: snapshot.data!.length,
itemBuilder: (context, index) {
return ListTile(
title: Text(snapshot.data![index].name),
);
},
);
} else {
return Text('No data');
}
},
);
}
}
Local Storage
Shared Preferences
dependencies:
shared_preferences: ^2.0.0
import 'package:shared_preferences/shared_preferences.dart';
// Save data
Future<void> saveData() async {
final prefs = await SharedPreferences.getInstance();
await prefs.setString('username', 'John');
await prefs.setInt('age', 30);
await prefs.setBool('isLoggedIn', true);
}
// Read data
Future<String?> readData() async {
final prefs = await SharedPreferences.getInstance();
return prefs.getString('username');
}
// Remove data
Future<void> removeData() async {
final prefs = await SharedPreferences.getInstance();
await prefs.remove('username');
}
SQLite
dependencies:
sqflite: ^2.0.0
path: ^1.8.0
import 'package:sqflite/sqflite.dart';
import 'package:path/path.dart';
class DatabaseHelper {
static final DatabaseHelper instance = DatabaseHelper._init();
static Database? _database;
DatabaseHelper._init();
Future<Database> get database async {
if (_database != null) return _database!;
_database = await _initDB('users.db');
return _database!;
}
Future<Database> _initDB(String filePath) async {
final dbPath = await getDatabasesPath();
final path = join(dbPath, filePath);
return await openDatabase(
path,
version: 1,
onCreate: _createDB,
);
}
Future _createDB(Database db, int version) async {
await db.execute('''
CREATE TABLE users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
email TEXT NOT NULL
)
''');
}
Future<int> insert(Map<String, dynamic> row) async {
final db = await database;
return await db.insert('users', row);
}
Future<List<Map<String, dynamic>>> queryAll() async {
final db = await database;
return await db.query('users');
}
Future<int> update(Map<String, dynamic> row) async {
final db = await database;
int id = row['id'];
return await db.update('users', row, where: 'id = ?', whereArgs: [id]);
}
Future<int> delete(int id) async {
final db = await database;
return await db.delete('users', where: 'id = ?', whereArgs: [id]);
}
}
Testing
Unit Tests
// test/counter_test.dart
import 'package:flutter_test/flutter_test.dart';
import 'package:my_app/counter.dart';
void main() {
test('Counter increments', () {
final counter = Counter();
counter.increment();
expect(counter.count, 1);
});
test('Counter decrements', () {
final counter = Counter();
counter.decrement();
expect(counter.count, -1);
});
}
Widget Tests
import 'package:flutter/material.dart';
import 'package:flutter_test/flutter_test.dart';
import 'package:my_app/main.dart';
void main() {
testWidgets('Counter increments smoke test', (WidgetTester tester) async {
// Build the widget
await tester.pumpWidget(MyApp());
// Verify initial state
expect(find.text('0'), findsOneWidget);
expect(find.text('1'), findsNothing);
// Tap the '+' icon and trigger a frame
await tester.tap(find.byIcon(Icons.add));
await tester.pump();
// Verify counter incremented
expect(find.text('0'), findsNothing);
expect(find.text('1'), findsOneWidget);
});
}
Deployment
Android
# Build APK
flutter build apk --release
# Build App Bundle (recommended)
flutter build appbundle --release
# Split APKs by ABI
flutter build apk --split-per-abi
iOS
# Build for iOS
flutter build ios --release
# Or use Xcode
# Open ios/Runner.xcworkspace
# Product > Archive
# Upload to App Store
Configuration
android/app/build.gradle:
android {
defaultConfig {
applicationId "com.example.myapp"
minSdkVersion 21
targetSdkVersion 33
versionCode 1
versionName "1.0.0"
}
signingConfigs {
release {
storeFile file("upload-keystore.jks")
storePassword System.getenv("KEYSTORE_PASSWORD")
keyAlias "upload"
keyPassword System.getenv("KEY_PASSWORD")
}
}
}
ios/Runner/Info.plist:
<key>CFBundleDisplayName</key>
<string>My App</string>
<key>CFBundleVersion</key>
<string>1</string>
<key>CFBundleShortVersionString</key>
<string>1.0.0</string>
Resources
Packages:
- pub.dev - Official package repository
- Flutter Awesome - Curated packages
Tools:
- Flutter DevTools
- DartPad - Online IDE
- FlutterFire - Firebase integration
Testing & Quality
Testing strategies, frameworks, and best practices for ensuring code quality and reliability.
Topics Covered
- Unit Testing: Testing individual functions and classes in isolation
- Integration Testing: Testing component interactions and APIs
- pytest: Python testing framework with fixtures and parametrization
- TDD: Test-driven development approaches and best practices
- Test Frameworks: pytest, Jest, unittest
- Mocking: Isolating code under test
- Coverage: Measuring test completeness
- Debugging: Finding and fixing issues
- Code Quality: Linting, formatting, static analysis
- Performance Testing: Load and stress testing
Testing Pyramid
E2E Tests (few)
Integration Tests (more)
Unit Tests (many)
Ratio: 70% unit, 20% integration, 10% e2e
Test Types
- Unit: Individual functions/classes
- Integration: Multiple components together
- End-to-End: Full user workflows
- Performance: Load and speed
- Security: Vulnerability detection
Best Practices
- Fast: Tests run quickly
- Independent: No dependencies between tests
- Repeatable: Consistent results
- Self-checking: Pass/fail obvious
- Timely: Written with code
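As a sketch, the properties above in pytest form; the Wallet class is hypothetical, standing in for real code under test:

```python
class Wallet:
    """Hypothetical class under test."""
    def __init__(self, balance=0):
        self.balance = balance

    def add(self, amount):
        self.balance += amount

def test_new_wallet_is_empty():
    # Independent: builds its own object, shares no state with other tests
    assert Wallet().balance == 0

def test_add_increases_balance():
    # Fast, repeatable, self-checking: same input, same asserted result
    wallet = Wallet(balance=20)
    wallet.add(30)
    assert wallet.balance == 50
```

Each test constructs everything it needs, so the tests can run in any order and still give consistent results.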
Unit Testing
Overview
Unit testing verifies individual functions or classes work correctly in isolation.
Python - pytest
# test_calculator.py
import pytest
from calculator import add, divide, Calculator
def test_add():
assert add(2, 3) == 5
assert add(-1, 1) == 0
def test_add_floats():
assert add(0.1, 0.2) == pytest.approx(0.3)
def test_divide_by_zero():
with pytest.raises(ValueError):
divide(10, 0)
@pytest.fixture
def calculator():
"""Setup fixture"""
return Calculator()
def test_with_fixture(calculator):
assert calculator.add(2, 3) == 5
@pytest.mark.parametrize("x,y,expected", [
(2, 3, 5),
(0, 0, 0),
(-1, 1, 0),
])
def test_add_multiple(x, y, expected):
assert add(x, y) == expected
JavaScript - Jest
// calculator.test.js
describe('Calculator', () => {
test('add function', () => {
expect(add(2, 3)).toBe(5);
});
test('divide by zero throws error', () => {
expect(() => divide(10, 0)).toThrow();
});
test('floating point', () => {
expect(add(0.1, 0.2)).toBeCloseTo(0.3);
});
});
Java - JUnit
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
class CalculatorTest {
private Calculator calc = new Calculator();
@Test
void testAdd() {
assertEquals(5, calc.add(2, 3));
}
@Test
void testDivideByZero() {
assertThrows(ArithmeticException.class,
() -> calc.divide(10, 0));
}
@BeforeEach
void setup() {
calc = new Calculator();
}
}
Mocking
Isolate code under test:
from unittest.mock import Mock, patch
@patch('module.external_api')
def test_with_mock(mock_api):
mock_api.return_value = {"status": "ok"}
result = my_function()
assert result == "success"
mock_api.assert_called_once()
Best Practices
- One assertion per test (or related)
- Arrange-Act-Assert pattern
- Descriptive names: test_add_positive_numbers
- Test behavior, not implementation
- Test edge cases and errors
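The last point, testing edge cases and errors, might look like this in pytest (divide here is a hypothetical function under test):

```python
import pytest

def divide(a, b):
    """Hypothetical function under test."""
    if b == 0:
        raise ValueError("division by zero")
    return a / b

@pytest.mark.parametrize("a,b,expected", [
    (10, 2, 5),    # happy path
    (0, 5, 0),     # zero numerator edge case
    (-9, 3, -3),   # negative input
])
def test_divide_edge_cases(a, b, expected):
    assert divide(a, b) == expected

def test_divide_by_zero_raises():
    # The error path is tested explicitly, not left to chance
    with pytest.raises(ValueError):
        divide(1, 0)
```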
Coverage
# Python coverage
pytest --cov=myapp tests/
# JavaScript coverage
npm test -- --coverage
Target: 80%+ coverage
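The target can be enforced automatically so the build fails when coverage drops; with coverage.py this is a one-line setting (a sketch, assuming a .coveragerc file at the project root):

```ini
# .coveragerc
[report]
fail_under = 80
```

pytest-cov accepts the equivalent on the command line: pytest --cov=myapp --cov-fail-under=80.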
Common Assertions
assert x == y # Equality
assert x > y # Comparison
assert x is None # Identity
assert x in list # Membership
with pytest.raises(Exception): # Exception
function()
ELI10
Unit tests are like checking individual pieces:
- Test each part separately
- Make sure it works alone
- Catch problems early!
Like quality control in a factory!
Integration Testing
Integration testing verifies that different modules or services work together correctly. Unlike unit tests that test individual components in isolation, integration tests validate interactions between components.
Overview
Integration tests validate:
- API endpoints
- Database interactions
- External service integrations
- Component interactions
- End-to-end workflows
Testing Strategies
Bottom-Up Approach
# Test data layer
def test_database_connection():
db = connect_to_database()
assert db.is_connected()
# Test service layer with real database
def test_user_service():
service = UserService(real_database)
user = service.create_user("test@example.com")
assert user.email == "test@example.com"
# Test API layer with real services
def test_api_endpoint():
response = client.post("/users", json={"email": "test@example.com"})
assert response.status_code == 201
Top-Down Approach
# Test API first with mocked services
def test_api_with_mocks():
with mock_user_service():
response = client.post("/users", json={"email": "test@example.com"})
assert response.status_code == 201
# Then test with real services
def test_api_with_real_services():
response = client.post("/users", json={"email": "test@example.com"})
user = db.query("SELECT * FROM users WHERE email = ?", "test@example.com")
assert user is not None
API Testing
# Flask example
from flask import Flask
from flask.testing import FlaskClient
def test_api_endpoints(client: FlaskClient):
# POST request
response = client.post('/api/users', json={
'username': 'testuser',
'email': 'test@example.com'
})
assert response.status_code == 201
data = response.get_json()
user_id = data['id']
# GET request
response = client.get(f'/api/users/{user_id}')
assert response.status_code == 200
assert response.json['username'] == 'testuser'
# PUT request
response = client.put(f'/api/users/{user_id}', json={
'email': 'newemail@example.com'
})
assert response.status_code == 200
# DELETE request
response = client.delete(f'/api/users/{user_id}')
assert response.status_code == 204
Database Testing
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
@pytest.fixture(scope="function")
def db_session():
# Create test database
engine = create_engine('sqlite:///:memory:')
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine)
session = Session()
yield session
session.close()
def test_user_crud(db_session):
# Create
user = User(username="test", email="test@example.com")
db_session.add(user)
db_session.commit()
# Read
retrieved = db_session.query(User).filter_by(username="test").first()
assert retrieved.email == "test@example.com"
# Update
retrieved.email = "updated@example.com"
db_session.commit()
# Delete
db_session.delete(retrieved)
db_session.commit()
assert db_session.query(User).count() == 0
Docker Compose for Testing
# docker-compose.test.yml
version: '3.8'
services:
postgres:
image: postgres:15
environment:
POSTGRES_DB: testdb
POSTGRES_USER: test
POSTGRES_PASSWORD: test
ports:
- "5432:5432"
redis:
image: redis:7
ports:
- "6379:6379"
app:
build: .
depends_on:
- postgres
- redis
environment:
DATABASE_URL: postgresql://test:test@postgres:5432/testdb
REDIS_URL: redis://redis:6379
command: pytest tests/integration/
# Run integration tests
docker-compose -f docker-compose.test.yml up --abort-on-container-exit
Test Fixtures and Setup
import pytest
@pytest.fixture(scope="session")
def app():
"""Create application for testing"""
app = create_app('testing')
return app
@pytest.fixture(scope="session")
def client(app):
"""Create test client"""
return app.test_client()
@pytest.fixture(scope="function")
def clean_database(db_session):
"""Clean database before each test"""
db_session.query(User).delete()
db_session.query(Order).delete()
db_session.commit()
yield
db_session.rollback()
def test_with_clean_db(client, clean_database):
response = client.post('/users', json={'username': 'test'})
assert response.status_code == 201
Mocking External Services
from unittest.mock import patch, Mock
def test_external_api_integration():
with patch('requests.get') as mock_get:
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {'data': 'test'}
mock_get.return_value = mock_response
result = fetch_external_data()
assert result['data'] == 'test'
mock_get.assert_called_once()
Best Practices
- Isolate tests: Each test should be independent
- Use test databases: Never test against production
- Clean state: Reset database/state between tests
- Test realistic scenarios: Use production-like data
- Fast feedback: Keep tests reasonably fast
- CI/CD integration: Run automatically on commits
- Test error cases: Not just happy paths
- Use containers: Docker for consistent environments
Quick Reference
| Aspect | Approach |
|---|---|
| Database | Use test DB or transactions |
| External APIs | Mock or use test endpoints |
| File system | Use temp directories |
| Time | Mock datetime |
| Network | Use test servers or mocks |
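Two rows from the table can be sketched with only the standard library: injecting a fixed datetime stands in for mocking the clock, and tempfile keeps the test off the real file system (make_report_name is a hypothetical function under test):

```python
import tempfile
from datetime import datetime
from pathlib import Path

def make_report_name(now=None):
    """Hypothetical code under test: names a file after the current day."""
    now = now or datetime.now()
    return f"report-{now:%Y-%m-%d}.txt"

def test_report_name_is_deterministic():
    # Inject a fixed datetime instead of reading the real clock
    fixed = datetime(2024, 1, 2)
    assert make_report_name(now=fixed) == "report-2024-01-02.txt"

def test_writes_into_temp_dir():
    # A throwaway directory is created and removed automatically
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "report.txt"
        path.write_text("ok")
        assert path.read_text() == "ok"
```

Where injection is not possible, unittest.mock.patch can replace the clock call instead.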
Integration tests ensure your system components work together correctly, catching issues that unit tests might miss.
Test-Driven Development (TDD)
Overview
TDD: Write tests BEFORE writing code. Red → Green → Refactor cycle.
Red-Green-Refactor Cycle
1. Red: Write Failing Test
def test_add_positive_numbers():
assert add(2, 3) == 5
2. Green: Write Minimal Code
def add(a, b):
return 5 # Hardcoded to pass test
3. Refactor: Improve Code
def add(a, b):
return a + b # Proper implementation
Benefits
✅ Better design (code written to be testable)
✅ Fewer bugs (test before shipping)
✅ Confidence (safe to refactor)
✅ Documentation (tests show usage)
✅ Less debugging (catch issues early)
Example: Calculator
Step 1: Red
class TestCalculator:
def test_add(self):
calc = Calculator()
assert calc.add(2, 3) == 5
Step 2: Green
class Calculator:
def add(self, a, b):
return a + b
Step 3: Refactor
class Calculator:
def add(self, a, b):
"""Add two numbers"""
if not isinstance(a, (int, float)):
raise TypeError("a must be number")
return a + b
TDD Best Practices
- Start simple: Test one behavior
- One assertion per test (usually)
- Clear names: test_add_positive_numbers
- Arrange-Act-Assert
def test_withdraw():
# Arrange
account = Account(1000)
# Act
account.withdraw(200)
# Assert
assert account.balance == 800
- Don't skip red phase: Ensures test can fail
Working Test Example
# calculator.py - EMPTY (start)
# test_calculator.py
def test_multiply():
# Test fails: function doesn't exist (RED)
assert multiply(3, 4) == 12
# calculator.py - implement
def multiply(a, b):
return a * b
# test passes (GREEN)
# Refactor if needed
Anti-patterns
❌ Writing all tests at once
❌ Over-engineering the implementation
❌ Ignoring red phase
❌ Poorly named tests
❌ Testing implementation, not behavior
Coverage with TDD
TDD naturally leads to high coverage:
# TDD tends to produce very high coverage (often 90%+), since code
# is only written to make a failing test pass; without TDD, coverage
# varies widely and is usually much lower
TDD vs BDD
TDD: Tests focus on unit behavior
test_add_positive_numbers()
test_add_negative_numbers()
BDD: Tests focus on business behavior
test_user_can_withdraw_money()
test_system_prevents_overdraft()
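A minimal sketch of those behavior-focused names, with a hypothetical Account class:

```python
class Account:
    """Hypothetical account used to illustrate behavior-named tests."""
    def __init__(self, balance):
        self.balance = balance

    def withdraw(self, amount):
        if amount > self.balance:
            raise ValueError("overdraft not allowed")
        self.balance -= amount

def test_user_can_withdraw_money():
    account = Account(1000)
    account.withdraw(200)
    assert account.balance == 800

def test_system_prevents_overdraft():
    account = Account(100)
    try:
        account.withdraw(500)
        assert False, "expected the overdraft to be rejected"
    except ValueError:
        assert account.balance == 100  # balance untouched on failure
```

The test names read as requirements, so a failure report tells you which behavior broke, not just which function.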
Tools
- pytest: Python testing
- Jest: JavaScript testing
- JUnit: Java testing
- RSpec: Ruby testing
ELI10
TDD is like building with blueprints:
- Draw blueprint (write test)
- Build to match (write code)
- Improve design (refactor)
Never start building without a plan!
pytest
pytest is a mature, feature-rich testing framework for Python that makes it easy to write simple tests, yet scales to support complex functional testing.
Installation
pip install pytest
pip install pytest pytest-cov pytest-mock pytest-xdist
# Verify
pytest --version
Basic Usage
# Run all tests
pytest
# Run specific file
pytest test_example.py
# Run specific test
pytest test_example.py::test_function
# Run with verbose output
pytest -v
# Run with coverage
pytest --cov=myapp tests/
# Parallel execution
pytest -n 4
Writing Tests
# test_example.py
# Simple test
def test_addition():
assert 1 + 1 == 2
# Test with setup
def test_list():
my_list = [1, 2, 3]
assert len(my_list) == 3
assert 2 in my_list
# Test exceptions
import pytest
def test_division_by_zero():
with pytest.raises(ZeroDivisionError):
1 / 0
# Parametrized test
@pytest.mark.parametrize("input,expected", [
(1, 2),
(2, 3),
(3, 4),
])
def test_increment(input, expected):
assert input + 1 == expected
Fixtures
import pytest
# Basic fixture
@pytest.fixture
def sample_data():
return [1, 2, 3, 4, 5]
def test_sum(sample_data):
assert sum(sample_data) == 15
# Fixture with setup/teardown
@pytest.fixture
def database_connection():
# Setup
conn = create_connection()
yield conn
# Teardown
conn.close()
# Scope: function (default), class, module, package, session
@pytest.fixture(scope="module")
def expensive_resource():
return load_expensive_data()
# Autouse fixture
@pytest.fixture(autouse=True)
def setup_test():
print("Setting up test")
yield
print("Tearing down test")
Markers
import pytest
import sys  # needed for the skipif example below
# Skip test
@pytest.mark.skip(reason="Not implemented yet")
def test_feature():
pass
# Skip conditionally
@pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python 3.8+")
def test_modern_feature():
pass
# Expected to fail
@pytest.mark.xfail
def test_known_bug():
assert False
# Custom marker
@pytest.mark.slow
def test_slow_operation():
pass
# Run specific markers
# pytest -m slow
# pytest -m "not slow"
Mocking
from unittest.mock import Mock, patch, MagicMock
def test_with_mock():
mock_obj = Mock()
mock_obj.method.return_value = 42
assert mock_obj.method() == 42
# Patch function
import module  # the (hypothetical) module under test
def test_with_patch():
with patch('module.function') as mock_func:
mock_func.return_value = 'mocked'
result = module.function()
assert result == 'mocked'
# pytest-mock plugin
def test_with_mocker(mocker):
mock = mocker.patch('module.function')
mock.return_value = 'mocked'
assert module.function() == 'mocked'
Quick Reference
| Command | Description |
|---|---|
| pytest | Run all tests |
| pytest -v | Verbose output |
| pytest -k pattern | Run tests matching pattern |
| pytest -m marker | Run tests with marker |
| pytest --cov | Coverage report |
| pytest -x | Stop on first failure |
| pytest --pdb | Drop into debugger on failure |
pytest is the de facto standard for Python testing with its simple syntax, powerful features, and extensive plugin ecosystem.
Debugging
This directory contains guides for debugging software at various levels.
Contents
- GDB - GNU Debugger for C/C++ applications
- Core Dumps - Analyzing program crashes
- Linux Kernel - Kernel-level debugging techniques
Common Debugging Workflow
- Reproduce the issue - Consistent reproduction is key
- Gather information - Logs, error messages, core dumps
- Isolate the problem - Narrow down the scope
- Form hypothesis - What could cause this?
- Test hypothesis - Use debuggers, logs, tests
- Fix and verify - Implement fix and confirm
Tools Overview
| Tool | Purpose | Level |
|---|---|---|
| gdb | Interactive debugging | Application |
| valgrind | Memory errors | Application |
| strace | System call tracing | Application/Kernel |
| ltrace | Library call tracing | Application |
| perf | Performance profiling | Application/Kernel |
| ftrace | Function tracing | Kernel |
| dmesg | Kernel messages | Kernel |
Effective debugging combines tools, techniques, and systematic thinking.
GDB (GNU Debugger)
GDB is the GNU Project debugger, allowing you to see what is going on inside a program while it executes or what it was doing at the moment it crashed. It's an essential tool for debugging C, C++, and other compiled languages.
Overview
GDB provides extensive facilities for tracing and altering program execution, including breakpoints, watchpoints, examining variables, and manipulating program state.
Key Features:
- Set breakpoints and watchpoints
- Step through code line by line
- Examine and modify variables
- Analyze core dumps
- Remote debugging
- Multi-threaded debugging
- Reverse debugging
- Python scripting support
Installation
# Ubuntu/Debian
sudo apt update
sudo apt install gdb
# macOS (gdb must be code-signed to attach to processes; lldb is the default there)
brew install gdb
# CentOS/RHEL
sudo yum install gdb
# Arch Linux
sudo pacman -S gdb
# Verify installation
gdb --version
Compiling for Debugging
# Compile with debug symbols (-g flag)
gcc -g program.c -o program
g++ -g program.cpp -o program
# Disable optimization for better debugging
gcc -g -O0 program.c -o program
# With all warnings
gcc -g -O0 -Wall -Wextra program.c -o program
# For C++ with debug symbols
g++ -g -std=c++17 program.cpp -o program
Basic Usage
Starting GDB
# Start GDB with program
gdb ./program
# With arguments
gdb --args ./program arg1 arg2
# Attach to running process
gdb -p <pid>
gdb attach <pid>
# Analyze core dump
gdb ./program core
# Quiet mode (no intro message)
gdb -q ./program
Basic Commands
# Running the program
(gdb) run # Start program
(gdb) run arg1 arg2 # Start with arguments
(gdb) start # Start and break at main()
(gdb) continue # Continue execution (c)
(gdb) kill # Kill running program
(gdb) quit # Exit GDB (q)
# Breakpoints
(gdb) break main # Break at function
(gdb) break main.c:42 # Break at line in file
(gdb) break *0x400500 # Break at address
(gdb) tbreak main # Temporary breakpoint
(gdb) info breakpoints # List breakpoints (info b)
(gdb) delete 1 # Delete breakpoint 1 (d 1)
(gdb) delete # Delete all breakpoints
(gdb) disable 1 # Disable breakpoint 1
(gdb) enable 1 # Enable breakpoint 1
# Stepping
(gdb) step # Step into (s)
(gdb) next # Step over (n)
(gdb) finish # Run until function returns
(gdb) until 50 # Run until line 50
(gdb) stepi # Step one instruction (si)
(gdb) nexti # Next instruction (ni)
# Examining code
(gdb) list # Show source code (l)
(gdb) list main # List function
(gdb) list 42 # List around line 42
(gdb) disassemble # Show assembly
(gdb) disassemble main # Disassemble function
# Stack and frames
(gdb) backtrace # Show call stack (bt)
(gdb) frame 0 # Switch to frame 0 (f 0)
(gdb) up # Move up stack frame
(gdb) down # Move down stack frame
(gdb) info frame # Current frame info
(gdb) info args # Function arguments
(gdb) info locals # Local variables
Examining Variables
# Print variables
(gdb) print variable # Print variable (p)
(gdb) print *pointer # Dereference pointer
(gdb) print array[5] # Array element
(gdb) print struct.member # Structure member
# Different formats
(gdb) print/x variable # Hexadecimal
(gdb) print/d variable # Decimal
(gdb) print/t variable # Binary
(gdb) print/c variable # Character
(gdb) print/f variable # Float
(gdb) print/s string_ptr # String
# Display (auto-print on each stop)
(gdb) display variable # Auto-display variable
(gdb) info display # Show display list
(gdb) undisplay 1 # Remove display 1
# Examine memory
(gdb) x/10x $rsp # Examine 10 hex words at stack pointer
(gdb) x/10i main # Examine 10 instructions at main
(gdb) x/s string_ptr # Examine string
(gdb) x/10b buffer # Examine 10 bytes
# Format: x/[count][format][size] address
# Format: x=hex, d=decimal, i=instruction, s=string, c=char
# Size: b=byte, h=halfword, w=word, g=giant (8 bytes)
# Set variables
(gdb) set variable x = 42 # Set variable value
(gdb) set $i = 0 # Set convenience variable
Watchpoints
# Watch for changes
(gdb) watch variable # Break when variable changes
(gdb) rwatch variable # Break when variable is read
(gdb) awatch variable # Break on read or write
# Conditional watchpoint
(gdb) watch x if x > 100
# Info and delete
(gdb) info watchpoints # List watchpoints
(gdb) delete 2 # Delete watchpoint 2
Conditional Breakpoints
# Set conditional breakpoint
(gdb) break main.c:42 if x == 5
# Add condition to existing breakpoint
(gdb) condition 1 x == 5
# Remove condition
(gdb) condition 1
# Commands to execute at breakpoint
(gdb) commands 1
> print x
> continue
> end
# Ignore breakpoint N times
(gdb) ignore 1 10 # Ignore first 10 hits
Thread Debugging
# Thread information
(gdb) info threads # List all threads
(gdb) thread 3 # Switch to thread 3
(gdb) thread apply all bt # Backtrace all threads
(gdb) thread apply all print x
# Thread-specific breakpoints
(gdb) break main.c:42 thread 2
# Non-stop mode (continue while other threads stop)
(gdb) set non-stop on
Core Dump Analysis
# Generate core dump
ulimit -c unlimited # Enable core dumps
# Debug core dump
gdb ./program core
# In GDB
(gdb) bt # See where it crashed
(gdb) frame 0 # Examine crash frame
(gdb) print variable # Check variable values
(gdb) info registers # CPU registers at crash
Advanced Features
Reverse Debugging
# Record execution
(gdb) record # Start recording
(gdb) record stop # Stop recording
# Reverse execution
(gdb) reverse-step # Step backward (rs)
(gdb) reverse-next # Next backward (rn)
(gdb) reverse-continue # Continue backward (rc)
(gdb) reverse-finish # Reverse to function call
Checkpoints
# Save program state
(gdb) checkpoint # Create checkpoint
(gdb) info checkpoints # List checkpoints
(gdb) restart 1 # Restore checkpoint 1
(gdb) delete checkpoint 1 # Delete checkpoint
Python Scripting
# Python in GDB
(gdb) python print("Hello from GDB")
# Load Python script
(gdb) source script.py
# Python example
(gdb) python
> for i in range(5):
> gdb.execute("print $i++")
> end
GDB Configuration
.gdbinit File
# ~/.gdbinit
set history save on
set history size 10000
set history filename ~/.gdb_history
set print pretty on
set print array on
set print array-indexes on
set python print-stack full
# Auto-load local .gdbinit
set auto-load safe-path /
# Custom commands
define phead
print *($arg0)->head
end
define ptail
print *($arg0)->tail
end
GDB Dashboard
# Install GDB Dashboard
wget -P ~ https://git.io/.gdbinit
# Or with curl
curl -sSL https://git.io/.gdbinit > ~/.gdbinit
# Customization in ~/.gdbinit.d/init
Common Patterns
Debugging Segmentation Fault
# Run program
(gdb) run
# When it crashes
Program received signal SIGSEGV, Segmentation fault.
# Check where it crashed
(gdb) backtrace
# Examine the failing instruction
(gdb) frame 0
(gdb) list
# Check variables
(gdb) print pointer
(gdb) print *pointer # This might fail if NULL
# Check registers
(gdb) info registers
Finding Memory Leaks
# Set breakpoint at allocation
(gdb) break malloc
(gdb) commands
> backtrace
> continue
> end
# Set breakpoint at free
(gdb) break free
(gdb) commands
> backtrace
> continue
> end
# Or use Valgrind instead
Debugging Infinite Loop
# Start program
(gdb) run
# Interrupt (Ctrl+C)
^C
Program received signal SIGINT
# Check where it's stuck
(gdb) backtrace
(gdb) list
# Set breakpoint and check variable changes
(gdb) break main.c:loop_line
(gdb) commands
> print loop_var
> continue
> end
Catching Signals
# Catch specific signal
(gdb) catch signal SIGSEGV
# Catch all signals
(gdb) catch signal all
# Info signals
(gdb) info signals
# Handle signal (pass, nopass, stop, nostop, print, noprint)
(gdb) handle SIGINT nostop print pass
Remote Debugging
GDB Server
# On remote machine
gdbserver :1234 ./program
# Or attach to running process
gdbserver :1234 --attach <pid>
# On local machine
gdb ./program
(gdb) target remote remote-host:1234
(gdb) continue
Serial/UART Debugging
# Connect via serial port
gdb ./program
(gdb) target remote /dev/ttyUSB0
# Set baud rate (if needed, in .gdbinit)
set serial baud 115200
TUI Mode (Text User Interface)
# Start TUI mode
(gdb) tui enable
(gdb) Ctrl+X A # Toggle TUI
# TUI layouts
(gdb) layout src # Source code
(gdb) layout asm # Assembly
(gdb) layout split # Source and assembly
(gdb) layout regs # Registers
# Window focus
(gdb) focus cmd # Focus command window
(gdb) focus src # Focus source window
# Refresh display
(gdb) Ctrl+L # Refresh screen
Useful Tricks
Pretty Printing
# Enable pretty printing
(gdb) set print pretty on
(gdb) set print array on
(gdb) set print array-indexes on
# STL pretty printers (C++)
(gdb) python
import sys
sys.path.insert(0, '/usr/share/gcc/python')
from libstdcxx.v6.printers import register_libstdcxx_printers
register_libstdcxx_printers(None)
end
# Now print STL containers nicely
(gdb) print my_vector
(gdb) print my_map
Logging
# Enable logging
(gdb) set logging on # Logs to gdb.txt
(gdb) set logging file mylog.txt
(gdb) set logging overwrite on
# Log and display
(gdb) set logging redirect off
Macros
# ~/.gdbinit
define plist
set $node = $arg0
while $node != 0
print *$node
set $node = $node->next
end
end
# Usage
(gdb) plist head
Function Breakpoints
# Break on all functions matching pattern
(gdb) rbreak ^my_.* # All functions starting with my_
# Break on exception throw (C++)
(gdb) catch throw
# Break on system calls
(gdb) catch syscall write
Debugging Optimized Code
# Problems with -O2, -O3
# Variables optimized away
# Inlining makes stepping difficult
# Solutions:
# 1. Compile with -Og (optimize for debugging)
gcc -g -Og program.c -o program
# 2. Disable specific optimizations
gcc -g -O2 -fno-inline program.c -o program
# 3. Use volatile for critical variables
volatile int debug_var;
# In GDB, skip inlined functions
(gdb) skip -rfu ^std::
Integration with Other Tools
Valgrind and GDB
# Run program under Valgrind with GDB server
valgrind --vgdb=yes --vgdb-error=0 ./program
# In another terminal
gdb ./program
(gdb) target remote | vgdb
GDB with Make
# Makefile
debug: program
gdb ./program
.PHONY: debug
GDB in VSCode
// .vscode/launch.json
{
"version": "0.2.0",
"configurations": [
{
"name": "GDB Debug",
"type": "cppdbg",
"request": "launch",
"program": "${workspaceFolder}/program",
"args": [],
"stopAtEntry": false,
"cwd": "${workspaceFolder}",
"environment": [],
"externalConsole": false,
"MIMode": "gdb",
"setupCommands": [
{
"description": "Enable pretty-printing",
"text": "-enable-pretty-printing",
"ignoreFailures": true
}
]
}
]
}
Troubleshooting
# Can't see source code
(gdb) directory /path/to/source
# Symbols not loaded
# Ensure compiled with -g
# Check symbols loaded
(gdb) info sources
# Can't set breakpoint
# Check function exists
(gdb) info functions pattern
# Program behavior different in GDB
# Try without breakpoints
# Timing-sensitive bugs
# GDB hangs
# Check for infinite loops in pretty printers
(gdb) set print elements 100
# Can't debug strip binary
# Need unstripped version or separate debug symbols
Quick Reference
| Command | Description |
|---|---|
| run | Start program |
| break | Set breakpoint |
| continue | Continue execution |
| step | Step into |
| next | Step over |
| print | Print variable |
| backtrace | Show stack |
| frame | Select frame |
| info locals | Show local variables |
| info args | Show function arguments |
| watch | Set watchpoint |
| list | Show source code |
| disassemble | Show assembly |
| quit | Exit GDB |
Keyboard Shortcuts
| Key | Action |
|---|---|
| Ctrl+C | Interrupt program |
| Ctrl+D | Exit GDB |
| Enter | Repeat last command |
| Ctrl+X A | Toggle TUI mode |
| Ctrl+L | Refresh screen |
| Ctrl+P | Previous command |
| Ctrl+N | Next command |
GDB is an indispensable tool for debugging compiled programs, offering powerful features for understanding program behavior, finding bugs, and analyzing crashes.
Core Dump Analysis
Core dumps are memory snapshots of a process at the moment it crashed, essential for post-mortem debugging.
Enable Core Dumps
# Check current limit
ulimit -c
# Enable unlimited core dumps
ulimit -c unlimited
# Make persistent (add to ~/.bashrc)
echo "ulimit -c unlimited" >> ~/.bashrc
# System-wide core dump configuration
sudo vim /etc/security/limits.conf
# Add: * soft core unlimited
Configure Core Dump Location
# Set core dump pattern
sudo sysctl -w kernel.core_pattern=/tmp/core-%e-%p-%t
# Options:
# %e - executable name
# %p - PID
# %t - timestamp
# %s - signal number
# %h - hostname
# Or use systemd-coredump
sudo sysctl -w 'kernel.core_pattern=|/lib/systemd/systemd-coredump %P %u %g %s %t %c %h'
Generate Test Core Dump
# From running process
kill -SEGV <pid>
# From code
#include <signal.h>
raise(SIGSEGV);
# Trigger with gdb
gdb ./program
(gdb) run
(gdb) generate-core-file
Analyze Core Dump with GDB
# Load core dump
gdb ./program core
# Or
gdb ./program core.12345
# GDB commands
(gdb) bt # Backtrace
(gdb) info threads # List threads
(gdb) thread 2 # Switch to thread 2
(gdb) frame 0 # Select frame
(gdb) info locals # Local variables
(gdb) print variable # Print variable
(gdb) info registers # CPU registers
(gdb) disassemble # Disassemble current function
Example Analysis Session
$ gdb ./myapp core.12345
(gdb) bt
#0 0x00007f8b9c5a7428 in __GI_raise ()
#1 0x00007f8b9c5a902a in __GI_abort ()
#2 0x0000000000401234 in my_function () at myapp.c:42
#3 0x0000000000401567 in main () at myapp.c:100
(gdb) frame 2
(gdb) list
37 int *ptr = NULL;
38 int value = 0;
39
40 // This will crash
41 value = *ptr;
42
43 return value;
(gdb) print ptr
$1 = (int *) 0x0
(gdb) info locals
ptr = 0x0
value = 0
Extract Information
# File information
file core.12345
# Strings in core
strings core.12345 | less
# Binary that produced core
file core.12345
# Look for "execfn:" in output
# All loaded libraries
gdb -batch -ex "info sharedlibrary" ./program core
Automated Analysis
# Generate backtrace
gdb -batch -ex "bt" ./program core > backtrace.txt
# All threads backtrace
gdb -batch -ex "thread apply all bt" ./program core > all_threads.txt
Core Dump with Containers
# Docker - enable core dumps
docker run --ulimit core=-1 myimage
# Kubernetes has no per-pod ulimit field; core dumps depend on the
# container runtime's default ulimits and the node's kernel.core_pattern,
# so configure these at the node or runtime level instead
Best Practices
- Always compile with debug symbols: gcc -g
- Keep matching binaries for core analysis
- Configure appropriate core dump location
- Set reasonable ulimit to prevent disk filling
- Use systemd-coredump for centralized management
- Strip production binaries but keep debug symbols separate
Core dumps are invaluable for debugging crashes in production systems.
Linux Kernel Debugging
Debugging the Linux kernel requires specialized tools and techniques due to its low-level nature.
Kernel Log (dmesg)
# View kernel messages
dmesg
# Follow kernel log
dmesg -w
dmesg --follow
# Filter by level
dmesg -l err,warn
# Human-readable timestamps
dmesg -T
# Clear ring buffer
sudo dmesg -C
Kernel Parameters
# View boot parameters
cat /proc/cmdline
# Add debug parameters (GRUB)
# Edit /etc/default/grub
GRUB_CMDLINE_LINUX="... debug ignore_loglevel"
# Update GRUB
sudo update-grub
printk Debugging
// In kernel code
#include <linux/printk.h>
printk(KERN_INFO "Debug: value = %d\n", value);
printk(KERN_ERR "Error occurred\n");
// Log levels
KERN_EMERG, KERN_ALERT, KERN_CRIT, KERN_ERR,
KERN_WARNING, KERN_NOTICE, KERN_INFO, KERN_DEBUG
KGDB (Kernel Debugger)
# Kernel configuration
CONFIG_KGDB=y
CONFIG_KGDB_SERIAL_CONSOLE=y
# Boot with kgdb
kgdboc=ttyS0,115200 kgdbwait
# Connect with GDB
gdb ./vmlinux
(gdb) target remote /dev/ttyS0
(gdb) continue
Kernel Oops Analysis
# When kernel oops occurs, check dmesg
dmesg | tail -100
# Decode with scripts
./scripts/decode_stacktrace.sh vmlinux < oops.txt
# addr2line for addresses
addr2line -e vmlinux -f -i 0xffffffffc0123456
SystemTap
# Install
sudo apt install systemtap
# Simple script
stap -e 'probe kernel.function("sys_open") { println("open called") }'
# Trace system calls
stap -e 'probe syscall.* { printf("%s\n", name) }'
ftrace
# Enable function tracing
cd /sys/kernel/debug/tracing
echo function > current_tracer
echo 1 > tracing_on
cat trace
# Trace specific function
echo sys_open > set_ftrace_filter
echo function > current_tracer
# Disable
echo 0 > tracing_on
Kernel Crash Dumps (kdump)
# Install kdump
sudo apt install kdump-tools
# Configure /etc/default/kdump-tools
USE_KDUMP=1
# Test (warning: this immediately crashes the kernel)
echo c | sudo tee /proc/sysrq-trigger
# Analyze with crash
crash /usr/lib/debug/boot/vmlinux-$(uname -r) /var/crash/*/dump.*
Kernel debugging requires patience and specialized knowledge, but these tools make it manageable.
Miscellaneous: Mathematical Foundations
Essential mathematical and statistical concepts with intuitive explanations for engineers, scientists, and technical professionals.
What's in This Section
This section contains foundational quantitative knowledge that underpins computer science, data science, engineering, and scientific computing:
📐 Mathematics
Comprehensive calculus guide with deep intuitive explanations covering:
- Limits and Continuity - Foundations of analysis
- Derivatives - Measuring instantaneous change
- Differentiation Techniques - Product rule, chain rule, implicit differentiation
- Integration - Accumulation and area under curves
- Integration Techniques - Substitution, integration by parts
- Sequences and Series - Infinite processes
- Multivariable Calculus - Partial derivatives, gradients
- Differential Equations - Modeling dynamic systems
890+ lines of content with:
- Intuitive explanations before formulas
- Visual analogies and mental models
- Real-world applications
- "Why it works" insights
- Common misconceptions addressed
📊 Statistics
Practical statistics guide focused on real-world applications:
- Descriptive Statistics - Mean, median, mode, when to use each
- Percentiles & Quantiles - p50, p90, p95, p99 deeply explained
- Variance & Standard Deviation - Measuring spread
- Probability Distributions - Normal, exponential, Poisson, long-tail
- Probability Basics - Conditional probability, Bayes' Theorem
- Statistical Inference - Confidence intervals, p-values, hypothesis testing
- Correlation & Regression - Correlation ≠ Causation
- Real-World Applications - Performance monitoring, A/B testing, reliability
900+ lines of content with:
- Software engineering focus
- SRE/DevOps examples
- Tail latency explained
- Percentiles for performance monitoring
- Common statistical pitfalls
📈 Matplotlib
Complete data visualization guide for Python:
- Architecture & Core Concepts - Figure, Axes, Artists hierarchy
- Basic Plotting - Line plots, scatter plots, bar charts
- Customization - Colors, styles, labels, legends, annotations
- Advanced Plot Types - Subplots, 3D plots, contours, heatmaps
- ML/Data Science Visualizations - Loss curves, confusion matrices, feature distributions
- Styling and Themes - Seaborn integration, custom styles
- Animations - Dynamic visualizations
- Performance & Best Practices - Efficient plotting for large datasets
Comprehensive guide with:
- Publication-quality visualizations
- Two interfaces: pyplot vs object-oriented
- Machine learning focused examples
- Integration patterns with NumPy, Pandas, Seaborn
- Common patterns and recipes
How These Topics Relate
Mathematics: The Theory
- What: Calculus and mathematical analysis
- When: Understanding change, optimization, modeling continuous systems
- For: Algorithm analysis, machine learning foundations, physics simulations
- Key Concepts: Derivatives, integrals, differential equations
Statistics: The Practice
- What: Data analysis and quantifying uncertainty
- When: Making decisions from data, monitoring systems, testing hypotheses
- For: Performance monitoring, A/B testing, capacity planning, reliability engineering
- Key Concepts: Percentiles, distributions, inference, correlation
The Connection
Calculus provides the continuous mathematics:
- How things change (derivatives)
- How to accumulate (integrals)
- Optimization (finding extrema)
- Modeling dynamics (differential equations)
Statistics provides the discrete/probabilistic mathematics:
- How to summarize data (descriptive statistics)
- How to quantify uncertainty (probability)
- How to make inferences (statistical inference)
- How to find relationships (correlation, regression)
Together, they form the quantitative foundation for:
- Machine Learning: Optimization (calculus) + probability (statistics)
- System Monitoring: Continuous metrics (calculus) + percentiles (statistics)
- Algorithm Analysis: Continuous complexity (calculus) + average case (statistics)
- Scientific Computing: Modeling (calculus) + uncertainty quantification (statistics)
Quick Reference Guide
When to Use Mathematics
Optimization Problems:
- Minimize cost, maximize profit
- Find critical points with derivatives
- Example: "What dimensions minimize material for a box of given volume?"
Rates of Change:
- Velocity, acceleration, growth rates
- Use derivatives
- Example: "How fast is temperature changing at this moment?"
Accumulation:
- Total distance from velocity
- Area under curve
- Use integration
- Example: "What's the total energy consumed over time?"
Modeling Dynamics:
- Systems that evolve over time
- Use differential equations
- Example: "How does population grow with limited resources?"
When to Use Statistics
System Performance:
- API latency, request rates
- Use percentiles (p50, p90, p95, p99)
- Example: "What's our p99 latency?" (better than average)
A/B Testing:
- Does feature A perform better than B?
- Use hypothesis testing, confidence intervals
- Example: "Is the new UI improving conversions?"
Capacity Planning:
- How many servers needed?
- Use distributions, percentiles
- Example: "Provision for p99 traffic, not average"
Reliability Engineering:
- Failure rates, uptime
- Use exponential distribution, MTBF
- Example: "What's our expected availability?"
Data Analysis:
- Understanding patterns in data
- Use descriptive statistics, visualization
- Example: "Why is median different from mean?"
Learning Path
For Software Engineers
Start with Statistics:
- Percentiles - Critical for performance monitoring
- Descriptive Statistics - Mean vs median
- Probability Basics - Understanding randomness
- Real-World Applications - SRE/DevOps examples
Then Mathematics:
- Derivatives - For understanding optimization
- Integration - For accumulation problems
- Limits - Foundational concepts
For Data Scientists
Start with Both:
- Statistics - Inference - Hypothesis testing
- Statistics - Correlation - Relationships
- Mathematics - Multivariable Calculus - Gradients
- Mathematics - Optimization - Finding extrema
For Machine Learning Engineers
Focused Path:
- Multivariable Calculus - Gradients for backpropagation
- Probability Distributions - Understanding data
- Optimization - Gradient descent
- Statistical Inference - Model evaluation
For System Reliability Engineers
Performance-Focused Path:
- Percentiles - p99 latency monitoring
- Distributions - Long-tail behavior
- Reliability Applications - MTBF, availability
- Probability - Failure rates
Key Takeaways
From Mathematics
- Derivatives measure instantaneous change - velocity, acceleration, sensitivity
- Integration is accumulation - total from rate, area under curve
- Optimization finds best values - where derivative equals zero
- Differential equations model dynamics - how systems evolve
From Statistics
- Mean hides outliers - use median or percentiles instead
- p99 matters at scale - 1% of 1M requests = 10,000 users
- Correlation ≠ Causation - relationships don't imply cause
- Percentiles reveal user experience - p50/p90/p95/p99 tell full story
- Variance matters - same mean, different experiences
Common Questions
Q: When do I need calculus vs statistics?
- Calculus: Continuous change, optimization, modeling dynamics
- Statistics: Data analysis, uncertainty, making decisions from samples
Q: Why are percentiles emphasized in statistics.md?
- In software systems, averages hide the worst-case experience
- p99 latency affects thousands of users at scale
- SLAs should use percentiles, not averages
Q: Do I need to master both?
- For software engineering: Statistics more immediately practical
- For ML/AI: Both essential (calculus for optimization, statistics for data)
- For system design: Statistics for monitoring, calculus for modeling
Q: What about linear algebra?
- Critical for ML but not yet in this section
- Complements both calculus and statistics
- Consider adding matrix operations, eigenvalues, SVD
Practical Wisdom
For Monitoring:
Always track: p50, p90, p95, p99
Alert on: p99 degradation
SLA: "p95 < 100ms" (not "average < 100ms")
For Optimization:
Find critical points: f'(x) = 0
Check second derivative: f''(x) > 0 → minimum
Verify constraints: boundaries matter
For Testing:
Sample size matters: larger → more confidence
Statistical significance ≠ practical significance
Report confidence intervals, not just p-values
For Capacity Planning:
Provision for p99, not average
Account for traffic growth (2x-3x)
Add headroom (1.5x-2x buffer)
Load test at target capacity
Further Learning
Books
- Calculus: "Calculus Made Easy" by Silvanus P. Thompson
- Statistics: "The Art of Statistics" by David Spiegelhalter
- Both: "Mathematics for Machine Learning" by Deisenroth et al.
Online Resources
- 3Blue1Brown (YouTube): Visualized calculus and linear algebra
- StatQuest (YouTube): Statistics and ML explained simply
- Khan Academy: Comprehensive math and statistics courses
Practice
- LeetCode: Apply math to algorithmic problems
- Kaggle: Apply statistics to real datasets
- Real Systems: Monitor your own services with percentiles
Contributing
Both documents are living resources. If you find:
- Errors or unclear explanations: Please report
- Missing concepts: Suggest additions
- Better intuitive explanations: Share them
- Real-world examples: We love practical applications
The goal is to make quantitative reasoning accessible and practical for technical professionals.
Last Updated: December 2024 Maintained By: Technical Knowledge Base Contributors
Fundamental Mathematical Concepts
A comprehensive guide to calculus and essential mathematical concepts.
Table of Contents
- Limits and Continuity
- Derivatives
- Differentiation Techniques
- Applications of Derivatives
- Integration
- Integration Techniques
- Applications of Integration
- Sequences and Series
- Multivariable Calculus
- Differential Equations
Limits and Continuity
Intuition: What Limits Really Mean
The Core Idea: A limit is about prediction, not arrival. It answers: "If I get arbitrarily close to a point, where is my function heading?" You care about the journey, not the destination.
Why Limits Matter: Real-world processes approach values without reaching them. A ball rolling toward a stop, a population approaching carrying capacity, an asymptote you'll never touch—limits capture this "tendency toward" behavior.
The Key Insight: The limit at a point can exist even if:
- The function isn't defined there (removable discontinuity)
- The function value is different from the limit (jump)
- You can never actually reach that point (approaching infinity)
Mental Model: Imagine walking toward a door. You can get 1 meter away, then 0.5m, then 0.25m, then 0.125m... You keep halving the distance. The limit is the door itself, even though in this infinite process you never quite touch it. That's the essence of a limit—where you're heading, not where you are.
Definition of a Limit
The limit of a function f(x) as x approaches a is L, written as:
lim(x→a) f(x) = L
Formal (ε-δ) Definition: For every ε > 0, there exists a δ > 0 such that if 0 < |x - a| < δ, then |f(x) - L| < ε.
Intuitive Definition: As x gets arbitrarily close to a (but not equal to a), f(x) gets arbitrarily close to L.
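The "approaching, not arriving" idea is easy to see numerically. A small Python sketch (the function sin(x)/x is just an example of one undefined at the point itself):

```python
import math

def f(x):
    # sin(x)/x is undefined at x = 0, but its limit there is 1
    return math.sin(x) / x

# Approach 0 with ever-smaller inputs
for h in [0.1, 0.01, 0.001, 1e-6]:
    print(f"f({h}) = {f(h):.10f}")
# The values approach 1.0 even though f(0) itself is undefined
```

Each value gets closer to 1: the limit describes where the function is heading, not a value it ever takes.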
Limit Laws
If lim(x→a) f(x) = L and lim(x→a) g(x) = M, then:
- Sum Rule: lim(x→a) [f(x) + g(x)] = L + M
- Difference Rule: lim(x→a) [f(x) - g(x)] = L - M
- Product Rule: lim(x→a) [f(x) · g(x)] = L · M
- Quotient Rule: lim(x→a) [f(x) / g(x)] = L / M (if M ≠ 0)
- Constant Multiple: lim(x→a) [c · f(x)] = c · L
- Power Rule: lim(x→a) [f(x)]^n = L^n
Types of Limits
One-Sided Limits:
- Right-hand limit: lim(x→a⁺) f(x)
- Left-hand limit: lim(x→a⁻) f(x)
- A limit exists if and only if both one-sided limits exist and are equal
Infinite Limits:
- lim(x→a) f(x) = ∞ (function grows without bound)
- lim(x→a) f(x) = -∞ (function decreases without bound)
Limits at Infinity:
- lim(x→∞) f(x) = L
- lim(x→-∞) f(x) = L
Continuity
A function f is continuous at x = a if:
- f(a) is defined
- lim(x→a) f(x) exists
- lim(x→a) f(x) = f(a)
Intuition: The Pencil Test: A function is continuous if you can draw its graph without lifting your pencil. No jumps, no holes, no breaks. Continuity means "no surprises"—small changes in input give small changes in output.
Why Three Conditions?
- Function must be defined: You need a value at the point (no hole)
- Limit must exist: Left and right approaches agree (no jump)
- They must match: Where you're going equals where you are (no removable discontinuity)
Real-World Connection: Temperature changes continuously through the day. You don't instantly jump from 20°C to 25°C. But a light switch has a discontinuity—it's OFF then suddenly ON, no in-between.
Types of Discontinuity:
- Removable: Limit exists but f(a) is undefined or different
- Jump: Left and right limits exist but are unequal
- Infinite: Function approaches ±∞
Important Theorems:
-
Intermediate Value Theorem (IVT): If f is continuous on [a,b] and k is between f(a) and f(b), then there exists c in (a,b) such that f(c) = k
Intuition: If you walk up a mountain continuously from elevation 100m to 300m, you must cross through 200m at some point. Continuous functions can't "skip" values. This is why roots exist—if f(a) < 0 and f(b) > 0, the function must cross zero somewhere between.
-
Extreme Value Theorem (EVT): A continuous function on a closed interval [a,b] attains both a maximum and minimum value
Intuition: On a closed, bounded hike, there's a highest point and lowest point. You can't have a highest point if the path goes to infinity (unbounded) or if there's a discontinuous jump (function not continuous). Both continuity and closed interval are essential.
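The IVT is also the reason bisection root-finding works. A minimal sketch in plain Python (the test function x² - 2 is an arbitrary example):

```python
def bisect(f, a, b, tol=1e-10):
    """Find a root of f in [a, b] using the IVT: if f is continuous and
    f(a), f(b) have opposite signs, a root must lie between them."""
    assert f(a) * f(b) < 0, "f(a) and f(b) must have opposite signs"
    while b - a > tol:
        mid = (a + b) / 2
        if f(a) * f(mid) <= 0:   # sign change in left half -> root is there
            b = mid
        else:                    # otherwise the root is in the right half
            a = mid
    return (a + b) / 2

# x^2 - 2 is negative at 1 and positive at 2, so a root lies between
root = bisect(lambda x: x * x - 2, 1.0, 2.0)
print(root)   # converges to sqrt(2)
```

Halving the interval forever relies entirely on continuity: a continuous function cannot skip over zero.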
Derivatives
Intuition: Measuring Instantaneous Change
The Central Question: How fast is something changing right now?
The Problem: We can easily calculate average change (rise over run), but how do we measure change at a single instant? There's no "run" at a point—it's just one location.
The Brilliant Solution: Get closer and closer to the instant. Make the time interval smaller and smaller. The derivative is what that average rate approaches as the interval shrinks to zero.
Why the Limit Definition?
f'(a) = lim(h→0) [f(a+h) - f(a)] / h
- f(a+h) - f(a): Change in output (rise)
- h: Change in input (run)
- Ratio: Average rate of change
- As h→0: Average becomes instantaneous
Visual Intuition: Draw a curve. Put two points close together and connect them with a line (secant). Now move the second point closer... closer... closer. That secant line becomes the tangent line. Its slope is the derivative.
Three Ways to Think About Derivatives:
- Geometric: Slope of the tangent line (best linear approximation)
- Physical: Instantaneous rate of change (velocity from position)
- Algebraic: Ratio of infinitesimal changes (dy/dx)
Real-World Power: The derivative lets us answer:
- How fast is the rocket accelerating right now?
- At what rate is the population growing at this instant?
- How sensitive is profit to a price change at this price point?
The Magic: Even though we can't divide by zero, limits let us see what "would happen" if we could. That's the derivative—the impossible made possible.
Definition
The derivative of f(x) at x = a is:
f'(a) = lim(h→0) [f(a+h) - f(a)] / h
Alternative form:
f'(a) = lim(x→a) [f(x) - f(a)] / (x - a)
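The limit definition can be watched in action: compute the difference quotient for shrinking h and see it settle on the derivative. A quick sketch (f(x) = x² at a = 3, where f'(3) = 6, chosen as an example):

```python
def derivative_estimate(f, a, h):
    # Average rate of change over [a, a+h]; approaches f'(a) as h -> 0
    return (f(a + h) - f(a)) / h

f = lambda x: x ** 2   # f'(x) = 2x, so f'(3) = 6
for h in [0.1, 0.01, 0.001, 1e-6]:
    print(h, derivative_estimate(f, 3.0, h))
# Estimates approach 6.0 as h shrinks
```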
Interpretations
Geometric: The derivative represents the slope of the tangent line to the curve at a point.
Physical: The derivative represents the instantaneous rate of change.
- If s(t) is position, then s'(t) is velocity
- If v(t) is velocity, then v'(t) is acceleration
Notation
Multiple notations for derivatives:
- Lagrange: f'(x), f''(x), f'''(x), f⁽ⁿ⁾(x)
- Leibniz: dy/dx, d²y/dx², dⁿy/dxⁿ
- Newton: ẋ, ẍ (for time derivatives)
- Euler: D_x f, D²_x f
Why So Many Notations?
- Lagrange's f'(x): Compact, emphasizes function
- Leibniz's dy/dx: Shows it's a ratio of changes, makes chain rule intuitive, great for manipulation
- Newton's ẋ: Perfect for physics where time is the variable
- Euler's D_x: Emphasizes the operator view (differentiation is an operation)
Each notation highlights a different aspect. Leibniz notation (dy/dx) is especially powerful because it reminds us that derivatives are ratios—even though dy and dx aren't real numbers, they behave algebraically like fractions in many contexts.
Basic Derivative Rules
-
Constant Rule: d/dx[c] = 0
- Intuition: Constants don't change. Derivative measures change, so zero change means zero derivative.
-
Power Rule: d/dx[x^n] = n·x^(n-1)
- Intuition: The power comes down as a multiplier, and the degree drops by one. Why? When you increase x slightly, x^n grows proportionally to n times the previous value. This is the pattern of exponential-like growth encoded in powers.
-
Constant Multiple: d/dx[c·f(x)] = c·f'(x)
- Intuition: Scaling doesn't change the rate pattern, just its magnitude. If f doubles, c·f doubles—same rate, scaled up.
-
Sum Rule: d/dx[f(x) + g(x)] = f'(x) + g'(x)
- Intuition: Changes add. If position is f+g, then velocity is f'+g'. Independent contributions to change sum linearly.
-
Difference Rule: d/dx[f(x) - g(x)] = f'(x) - g'(x)
- Intuition: Same as sum rule, but subtracting. The rate of change of a difference is the difference of rates.
Higher-Order Derivatives
- First derivative: f'(x) or dy/dx - rate of change
- Second derivative: f''(x) or d²y/dx² - rate of change of rate of change (concavity)
- Third derivative: f'''(x) or d³y/dx³ - jerk (in physics)
- nth derivative: f⁽ⁿ⁾(x) or dⁿy/dxⁿ
Intuition for Higher Derivatives:
- First derivative (f'): The speedometer—how fast you're going
- Second derivative (f''): The accelerometer—how fast your speed is changing
- Third derivative (f'''): The "jerk meter"—how fast your acceleration is changing (why sudden braking feels jarring)
Why Second Derivatives Matter: They measure the curvature of change:
- f' tells you the slope
- f'' tells you if the slope is increasing or decreasing
- This reveals the shape of the curve
Concavity:
-
f''(x) > 0 ⇒ concave up (curve opens upward) - "holds water" - smiling face ∪ Meaning: Slope is increasing. The function is accelerating upward.
-
f''(x) < 0 ⇒ concave down (curve opens downward) - "spills water" - frowning face ∩ Meaning: Slope is decreasing. The function is accelerating downward.
-
f''(x) = 0 ⇒ possible inflection point Meaning: The curvature changes. Like the middle of an S-curve where the turn reverses.
Physical Intuition:
- Position → Velocity → Acceleration
- Cost → Marginal Cost → Rate of change of marginal cost
- Each derivative is "one level deeper" into understanding change
Differentiation Techniques
Product Rule
If u and v are differentiable functions:
d/dx[u·v] = u'·v + u·v'
Intuition: When two things multiply and both are changing, you get contributions from each:
- u'·v: Change in u, holding v constant
- u·v': Change in v, holding u constant
Think of area of a rectangle with changing width u and height v. The area changes in two ways: width changes (u' times v), and height changes (u times v'). Both contribute to how the total area changes.
Memory trick: "First times derivative of second, plus second times derivative of first"
Example: d/dx[x²·sin(x)] = 2x·sin(x) + x²·cos(x)
Quotient Rule
d/dx[u/v] = (u'·v - u·v') / v²
Intuition: A fraction changes when:
- Numerator increases: Fraction goes up → positive contribution (u'·v)
- Denominator increases: Fraction goes down → negative contribution (-u·v')
- Divide by v²: Normalize by the square of denominator
Why the minus sign? When the bottom gets bigger, the fraction gets smaller. That's the opposite (negative) effect.
Memory trick: "Low dee-high minus high dee-low, over the square of what's below"
- Low (v) × derivative of high (u')
- Minus high (u) × derivative of low (v')
- Over low squared (v²)
Pro tip: Often easier to rewrite as u·v⁻¹ and use product rule + chain rule!
Example: d/dx[sin(x)/x] = [x·cos(x) - sin(x)] / x²
Chain Rule
For composite functions f(g(x)):
d/dx[f(g(x))] = f'(g(x))·g'(x)
Or in Leibniz notation:
dy/dx = (dy/du)·(du/dx)
Intuition: Nested Change
The chain rule captures how change propagates through nested functions. It's the mathematical expression of cause-and-effect chains.
The Principle: If A affects B, and B affects C, then A's effect on C is the product of:
- How much B changes when A changes (inner derivative)
- How much C changes when B changes (outer derivative)
Why Multiply? Changes compound multiplicatively through composition:
- If x changes by small amount dx
- Then g(x) changes by approximately g'(x)·dx
- Then f(g(x)) changes by approximately f'(g(x))·[g'(x)·dx]
- So the total rate is f'(g(x))·g'(x)
Leibniz notation magic: dy/dx = (dy/du)·(du/dx) looks like fractions canceling! While not rigorous, it's a powerful mnemonic and often works algebraically.
Visual: Imagine zooming through nested magnifications. Each layer magnifies by its derivative. Total magnification is the product of all layers.
Real-World Example:
- Distance depends on time: d = f(t)
- Time depends on temperature: t = g(T)
- How does distance change with temperature? dd/dT = (dd/dt)·(dt/dT)
- Chain rule connects indirect relationships!
Example: d/dx[sin(x²)] = cos(x²)·2x = 2x·cos(x²)
- Outer function: sin(u) → derivative is cos(u)
- Inner function: u = x² → derivative is 2x
- Evaluate outer derivative at inner function: cos(x²)
- Multiply by inner derivative: cos(x²)·2x
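The chain-rule formula can be cross-checked numerically against a difference quotient (the point x = 1.3 is arbitrary):

```python
import math

def numeric_derivative(f, x, h=1e-7):
    # Symmetric difference quotient: more accurate than one-sided
    return (f(x + h) - f(x - h)) / (2 * h)

x = 1.3
numeric = numeric_derivative(lambda t: math.sin(t ** 2), x)
chain_rule = 2 * x * math.cos(x ** 2)   # d/dx[sin(x^2)] = cos(x^2)·2x
print(numeric, chain_rule)              # the two values agree
```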
Implicit Differentiation
When a relation is given implicitly (not solved for y):
Steps:
- Differentiate both sides with respect to x
- Apply chain rule to terms with y (multiply by dy/dx)
- Solve for dy/dx
Intuition: Sometimes you can't (or don't want to) solve for y explicitly. No problem! Differentiate the relationship itself.
Key Insight: y is a function of x, even if we haven't written y = f(x). So when differentiating y terms, use the chain rule—y's derivative with respect to x is dy/dx (which we're solving for).
Why It Works: The equation defines a relationship. Differentiation preserves that relationship. Both sides must change at the same rate to maintain the equation.
Mental Model: Think of x and y as linked by a constraint. When x changes, y must change in a specific way to keep the constraint satisfied. Implicit differentiation finds that required rate.
Example: x² + y² = 25 (circle equation)
2x + 2y·(dy/dx) = 0
dy/dx = -x/y
Interpretation: At any point on the circle, the slope is -x/y. This is the tangent to the circle!
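The implicit result dy/dx = -x/y can be verified by differentiating the upper semicircle explicitly (the point (3, 4) on the circle is an example):

```python
import math

# Point on the circle x^2 + y^2 = 25
x = 3.0
y = 4.0
implicit_slope = -x / y   # implicit differentiation: dy/dx = -x/y

# Cross-check: upper semicircle y = sqrt(25 - x^2), differentiated numerically
def upper(x):
    return math.sqrt(25 - x * x)

h = 1e-7
explicit_slope = (upper(x + h) - upper(x - h)) / (2 * h)
print(implicit_slope, explicit_slope)   # both are -0.75
```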
Logarithmic Differentiation
Useful for products, quotients, and powers of functions:
Steps:
- Take ln of both sides
- Use logarithm properties to simplify
- Differentiate implicitly
- Solve for dy/dx
Intuition: Logarithms convert multiplication to addition, division to subtraction, and powers to multiplication. This transforms messy products/quotients/powers into simple sums/differences.
Why Take ln? Logarithms are the perfect tool for:
- Products: ln(ab) = ln(a) + ln(b) → sum rule instead of product rule
- Quotients: ln(a/b) = ln(a) - ln(b) → difference instead of quotient rule
- Powers: ln(a^b) = b·ln(a) → brings exponents down as multipliers
When to Use:
- Variable in both base and exponent (x^x)
- Complicated products of many functions
- Complicated quotients
- Functions raised to function powers
The Magic: ln converts complex derivative rules into simple arithmetic!
Example: y = x^x (variable base and exponent!)
ln(y) = x·ln(x)
(1/y)·(dy/dx) = ln(x) + 1
dy/dx = y·(ln(x) + 1) = x^x·(ln(x) + 1)
Why it works: Without ln, we'd struggle with x^x (power rule needs constant exponent, exponential rule needs constant base). Logarithm untangles it!
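The x^x result is easy to sanity-check numerically (x = 2 is an arbitrary test point):

```python
import math

def f(x):
    return x ** x

x = 2.0
formula = f(x) * (math.log(x) + 1)   # from logarithmic differentiation

h = 1e-7
numeric = (f(x + h) - f(x - h)) / (2 * h)   # symmetric difference quotient
print(formula, numeric)   # both values agree
```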
Parametric Differentiation
For curves defined parametrically: x = f(t), y = g(t)
dy/dx = (dy/dt) / (dx/dt)
Second derivative:
d²y/dx² = d/dx[dy/dx] = [d/dt(dy/dx)] / (dx/dt)
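As a sketch of the first formula, take the unit circle traced parametrically and compare dy/dx with the implicit result from earlier (t = 1 is an arbitrary parameter value):

```python
import math

# Circle traced parametrically: x = cos(t), y = sin(t)
t = 1.0
dydt = math.cos(t)     # dy/dt
dxdt = -math.sin(t)    # dx/dt
slope = dydt / dxdt    # dy/dx = (dy/dt) / (dx/dt)

# Cross-check against the implicit result dy/dx = -x/y on x^2 + y^2 = 1
x, y = math.cos(t), math.sin(t)
print(slope, -x / y)   # same value both ways
```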
Common Derivatives
Trigonometric Functions:
- d/dx[sin(x)] = cos(x)
- d/dx[cos(x)] = -sin(x)
- d/dx[tan(x)] = sec²(x)
- d/dx[cot(x)] = -csc²(x)
- d/dx[sec(x)] = sec(x)·tan(x)
- d/dx[csc(x)] = -csc(x)·cot(x)
Inverse Trigonometric Functions:
- d/dx[arcsin(x)] = 1/√(1-x²)
- d/dx[arccos(x)] = -1/√(1-x²)
- d/dx[arctan(x)] = 1/(1+x²)
Exponential and Logarithmic Functions:
- d/dx[e^x] = e^x
- d/dx[a^x] = a^x·ln(a)
- d/dx[ln(x)] = 1/x
- d/dx[log_a(x)] = 1/(x·ln(a))
Hyperbolic Functions:
- d/dx[sinh(x)] = cosh(x)
- d/dx[cosh(x)] = sinh(x)
- d/dx[tanh(x)] = sech²(x)
Applications of Derivatives
Critical Points and Extrema
Critical Point: x = c where f'(c) = 0 or f'(c) does not exist
Intuition: Finding the Best
Why Derivative = 0? At a peak or valley, the slope is horizontal (neither going up nor down). That's where f'(x) = 0. It's a moment of transition—the function stops increasing and starts decreasing (or vice versa).
The Physical Picture:
- Imagine hiking on a mountain path
- At the top of a hill: you stop going up and start going down → slope = 0 → local max
- At the bottom of a valley: you stop going down and start going up → slope = 0 → local min
- Critical points are potential peaks and valleys
Why Also Check Where f' Doesn't Exist? Sharp corners and cusps can be extrema even without f' = 0. Think of a spike—it's a maximum even though there's no horizontal tangent.
Finding Extrema:
- Find all critical points
- Use First Derivative Test or Second Derivative Test
- Check endpoints (for closed intervals)
First Derivative Test (Sign Analysis):
- If f' changes from + to - at c, then f has a local maximum at c Intuition: Function rises then falls → peak
- If f' changes from - to + at c, then f has a local minimum at c Intuition: Function falls then rises → valley
Second Derivative Test (Concavity):
- If f'(c) = 0 and f''(c) > 0, then f has a local minimum at c Intuition: Concave up (∪ shape) + horizontal tangent → bottom of bowl
- If f'(c) = 0 and f''(c) < 0, then f has a local maximum at c Intuition: Concave down (∩ shape) + horizontal tangent → top of dome
- If f''(c) = 0, test is inconclusive Intuition: Could be inflection point, not extremum
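The two tests can be walked through on a concrete function. A sketch using the cubic f(x) = x³ - 3x (an example with one peak and one valley):

```python
def f(x):
    return x ** 3 - 3 * x     # a cubic with two critical points

def fp(x):
    return 3 * x ** 2 - 3     # f'(x); zero at x = -1 and x = 1

def fpp(x):
    return 6 * x              # f''(x); its sign gives concavity

for c in (-1.0, 1.0):         # the critical points, where f'(c) = 0
    kind = "local min" if fpp(c) > 0 else "local max"
    print(f"x = {c}: f''({c}) = {fpp(c)} -> {kind}")
```

At x = -1 the curve is concave down (local max); at x = 1 it is concave up (local min), exactly as the second derivative test predicts.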
Optimization Problems
Intuition: Finding the Best in Real Life
Optimization is about making the best choice given constraints. Maximum profit, minimum cost, shortest distance, largest area—these are all optimization problems.
The Key Insight: "Best" happens where you can't improve by making small changes. That's exactly where the derivative is zero—tiny changes don't help (first-order improvement is zero).
Real-World Examples:
- Farmer: What dimensions maximize area with fixed fence length?
- Company: What price maximizes profit?
- Engineer: What design minimizes material while meeting strength requirements?
Why Constraints Matter: They reduce freedom. With constraints, you can eliminate variables and reduce to a one-variable optimization problem that calculus can solve.
General Strategy:
- Identify the quantity to optimize (write as a function)
- Identify constraints
- Use constraints to express the quantity as a function of one variable
- Find critical points
- Determine which critical point gives the optimal value
Pro Tip: Always check endpoints and boundaries. Sometimes the best solution is at an extreme constraint, not at a critical point.
Related Rates
For quantities that change with respect to time:
Intuition: Everything is Connected
Related rates problems capture how changes in one quantity affect another when they're linked by a relationship. It's the mathematics of interconnected change.
The Core Idea: If two variables are related by an equation, their rates of change are also related. Differentiate the relationship to find how rates connect.
Why "Related"? When x and y satisfy an equation, they're not independent. As x changes, y must change in a compatible way. Their rates of change (dx/dt and dy/dt) are thus related through the same equation structure.
Real-World Examples:
- Balloon inflating: radius grows → volume grows (but at what rate?)
- Shadow lengthening: person walks → shadow extends (how fast?)
- Water draining: height drops → volume drops (connection?)
- Ladder sliding: bottom slides out → top slides down (how are these rates related?)
The Process:
- Identify the relationship between variables (geometric or physical)
- Differentiate the entire relationship with respect to time
- The result links the rates of change
Strategy:
- Draw a diagram and label variables
- Write an equation relating the variables
- Differentiate both sides with respect to time t
- Substitute known values
- Solve for the desired rate
Example: A ladder sliding down a wall
x² + y² = L²
2x·(dx/dt) + 2y·(dy/dt) = 0
Interpretation: As bottom moves out (dx/dt), top must move down (dy/dt) to maintain constant ladder length L. The rates are inversely related through the geometry.
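Plugging numbers into the differentiated relation makes this concrete (ladder length 5 m, base 3 m out, sliding at 1 m/s are example values):

```python
import math

L = 5.0                           # ladder length (m)
x = 3.0                           # distance of base from wall (m)
y = math.sqrt(L ** 2 - x ** 2)    # height on wall: y = 4
dxdt = 1.0                        # base slides out at 1 m/s

# From 2x·(dx/dt) + 2y·(dy/dt) = 0, solve for dy/dt:
dydt = -x * dxdt / y
print(dydt)   # -0.75: the top slides down at 0.75 m/s
```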
Mean Value Theorem (MVT)
If f is continuous on [a,b] and differentiable on (a,b), then there exists c in (a,b) such that:
f'(c) = [f(b) - f(a)] / (b - a)
Interpretation: There exists a point where the instantaneous rate equals the average rate.
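A tiny worked example of the MVT, using f(x) = x² on [0, 2] (chosen because the point c can be solved exactly):

```python
f = lambda x: x ** 2
a, b = 0.0, 2.0
avg_rate = (f(b) - f(a)) / (b - a)   # average rate over [0, 2] = 2.0

# f'(x) = 2x, so f'(c) = avg_rate exactly when c = avg_rate / 2
c = avg_rate / 2
print(c)   # c = 1.0 lies in (0, 2), as the MVT guarantees
```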
Linear Approximation
The tangent line approximation at x = a:
L(x) = f(a) + f'(a)·(x - a)
For small Δx:
f(a + Δx) ≈ f(a) + f'(a)·Δx
Differentials:
- dx = Δx (change in x)
- dy = f'(x)·dx (change along the tangent line)
- Δy = f(x + dx) - f(x) (actual change in f)
L'Hôpital's Rule
For indeterminate forms 0/0 or ∞/∞:
lim(x→a) [f(x)/g(x)] = lim(x→a) [f'(x)/g'(x)]
Can be applied repeatedly if the result is still indeterminate.
Other indeterminate forms (0·∞, ∞ - ∞, 0⁰, 1^∞, ∞⁰) can be converted to 0/0 or ∞/∞ form.
Curve Sketching
Complete Analysis:
- Domain and range
- Intercepts (x and y)
- Symmetry (even, odd, periodic)
- Asymptotes (vertical, horizontal, oblique)
- First derivative (increasing/decreasing, local extrema)
- Second derivative (concavity, inflection points)
- Plot key points and sketch
Integration
Intuition: Accumulation and Reverse Engineering
The Big Picture: Integration is about accumulation—adding up infinitely many infinitesimally small pieces. It's also the reverse of differentiation.
Two Perspectives on Integration:
- Geometric (Area/Accumulation):
  - Slice a region into infinitely thin rectangles
  - Add up their areas: height f(x) times width dx
  - As the rectangles get infinitesimally thin, the sum becomes an integral
  - Result: the area under the curve
- Algebraic (Antiderivative):
  - Derivative breaks things apart (rate of change)
  - Integral builds things back up (accumulation from rate)
  - If F'(x) = f(x), then ∫f(x)dx = F(x) + C
  - Integration "undoes" differentiation
Why Integration Matters: Whenever you know a rate and want the total:
- Know velocity → find displacement
- Know flow rate → find total volume
- Know marginal cost → find total cost
- Know rate of growth → find population
The Fundamental Question: Given how fast something is changing (derivative), what is the thing itself (original function)?
Why the dx? It's not just notation—it represents an infinitesimal width. The integral is literally a sum: ∫f(x)dx = "sum of f(x) times infinitesimal dx pieces". Think of it as lim(Δx→0) Σf(x)Δx.
The "+ C" Mystery: When you differentiate, constants vanish (derivative of constant = 0). So when you integrate (reverse), you can't know what constant was there. Could be any C!
Antiderivatives
An antiderivative of f(x) is a function F(x) such that F'(x) = f(x).
General Antiderivative: F(x) + C, where C is an arbitrary constant.
Indefinite Integrals
The indefinite integral represents the family of all antiderivatives:
∫ f(x) dx = F(x) + C
Definite Integrals
The definite integral from a to b:
∫[a to b] f(x) dx
Geometric Interpretation: The signed area between the curve and the x-axis from a to b.
Properties:
- ∫[a to b] c·f(x) dx = c·∫[a to b] f(x) dx
- ∫[a to b] [f(x) ± g(x)] dx = ∫[a to b] f(x) dx ± ∫[a to b] g(x) dx
- ∫[a to b] f(x) dx = -∫[b to a] f(x) dx
- ∫[a to a] f(x) dx = 0
- ∫[a to b] f(x) dx + ∫[b to c] f(x) dx = ∫[a to c] f(x) dx
Fundamental Theorem of Calculus
The Most Important Theorem in Calculus
This theorem is the bridge connecting derivatives and integrals—two concepts that seem completely different but are actually inverse operations.
Part 1: If f is continuous on [a,b] and F(x) = ∫[a to x] f(t) dt, then F'(x) = f(x).
Intuition for Part 1:
- F(x) = accumulated area from a to x
- When you increase x slightly to x + dx, you add a thin rectangle of area ≈ f(x)·dx
- Rate of change of accumulated area = height of function
- Profound Insight: Accumulating f gives you something whose rate of change is f. Integration and differentiation are inverses!
Analogy: If f(t) is your speedometer reading and F(x) is your odometer, then:
- Odometer accumulates distance: F(x) = ∫ speed
- Speedometer is rate of distance change: f(x) = F'(x)
- They're inverses of each other!
Part 2: If f is continuous on [a,b] and F is any antiderivative of f, then:
∫[a to b] f(x) dx = F(b) - F(a)
Intuition for Part 2:
- Want to find area under curve from a to b
- Instead of summing infinitely many rectangles (hard!)
- Just find ANY function F whose derivative is f
- Evaluate F at endpoints and subtract: F(b) - F(a)
- This is miraculous: Infinite sum reduced to two function evaluations!
Why It Works:
- F tracks cumulative change
- F(b) = total accumulated from start to b
- F(a) = total accumulated from start to a
- F(b) - F(a) = accumulated from a to b
- That's exactly the integral!
The Power: This theorem transforms an infinitely complex problem (summing infinite pieces) into simple algebra (evaluate, subtract). It's why calculus is so powerful!
Historical Note: Newton and Leibniz's great insight wasn't derivatives or integrals separately—many knew about those. The breakthrough was realizing they're inverses (this theorem). That unified calculus and unlocked its power.
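The "infinite sum reduced to two function evaluations" claim is easy to verify numerically. A minimal sketch, using f(x) = x² and its antiderivative F(x) = x³/3 on [0, 2] (an assumed example, not from the text):

```python
# FTC Part 2 check: a sum of many thin rectangles vs F(b) - F(a)
# for f(x) = x^2 on [0, 2], where F(x) = x^3 / 3.
def riemann_midpoint(f, a, b, n=100_000):
    dx = (b - a) / n
    # Sum rectangle areas f(midpoint) * dx across n subintervals.
    return sum(f(a + (i + 0.5) * dx) for i in range(n)) * dx

f = lambda x: x**2
F = lambda x: x**3 / 3

approx = riemann_midpoint(f, 0, 2)
exact = F(2) - F(0)                # 8/3, from two evaluations
print(abs(approx - exact) < 1e-6)  # True
```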
Basic Integration Formulas
- ∫ k dx = kx + C
- ∫ x^n dx = x^(n+1)/(n+1) + C (n ≠ -1)
- ∫ (1/x) dx = ln|x| + C
- ∫ e^x dx = e^x + C
- ∫ a^x dx = a^x/ln(a) + C
- ∫ sin(x) dx = -cos(x) + C
- ∫ cos(x) dx = sin(x) + C
- ∫ sec²(x) dx = tan(x) + C
- ∫ csc²(x) dx = -cot(x) + C
- ∫ sec(x)tan(x) dx = sec(x) + C
- ∫ csc(x)cot(x) dx = -csc(x) + C
- ∫ 1/√(1-x²) dx = arcsin(x) + C
- ∫ 1/(1+x²) dx = arctan(x) + C
Riemann Sums
The definite integral is the limit of Riemann sums:
∫[a to b] f(x) dx = lim(n→∞) Σ[i=1 to n] f(x_i*)·Δx
where Δx = (b-a)/n and x_i* is a sample point in the ith subinterval.
Types:
- Left Riemann Sum: Use left endpoints
- Right Riemann Sum: Use right endpoints
- Midpoint Rule: Use midpoints
- Trapezoidal Rule: Average of left and right
- Simpson's Rule: Uses parabolic approximation
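The first four sum types above can be sketched in a few lines of Python; the test function f(x) = 1/x on [1, 2] is an assumed example chosen because the exact answer is ln(2):

```python
# Left, right, and trapezoidal Riemann sums for f(x) = 1/x on [1, 2];
# the exact integral is ln(2) ≈ 0.6931.
import math

def riemann(f, a, b, n, kind):
    dx = (b - a) / n
    if kind == "left":
        return sum(f(a + i * dx) for i in range(n)) * dx
    if kind == "right":
        return sum(f(a + (i + 1) * dx) for i in range(n)) * dx
    if kind == "trap":  # average of the left and right endpoint values
        return sum((f(a + i * dx) + f(a + (i + 1) * dx)) / 2 for i in range(n)) * dx

f = lambda x: 1 / x
left = riemann(f, 1, 2, 1000, "left")    # overestimates (f is decreasing)
right = riemann(f, 1, 2, 1000, "right")  # underestimates
trap = riemann(f, 1, 2, 1000, "trap")    # average of the two

print(right < math.log(2) < left)  # True
```

For a decreasing function the left sum brackets the integral from above and the right sum from below, which is why their average (the trapezoidal rule) does much better.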
Integration Techniques
Substitution (u-Substitution)
Method: Let u = g(x), then du = g'(x)dx
Intuition: Reverse Chain Rule
u-substitution is the integration version of the chain rule. It recognizes that your integrand came from a chain rule differentiation, and "undoes" it.
The Key Insight: If you see f(g(x))·g'(x), this came from differentiating F(g(x)) via chain rule:
- d/dx[F(g(x))] = F'(g(x))·g'(x) = f(g(x))·g'(x)
- So ∫f(g(x))·g'(x)dx = F(g(x)) + C
When to Use: Look for:
- A composite function f(g(x))
- Whose "inside function's" derivative g'(x) appears as a factor
- Pattern: ∫[stuff]'·[function of stuff] → substitute u = stuff
Why It Works: The du = g'(x)dx substitution absorbs the chain rule's g'(x) term, reducing the composite function to a simple function of u.
Mental Model: You're "peeling off" the outer layer of composition. The integral becomes simpler in terms of the inner function.
The Art: Choosing the right u. Look for the "inner function" whose derivative (or a multiple) appears elsewhere in the integrand.
Steps:
- Choose substitution u = g(x)
- Calculate du = g'(x)dx
- Rewrite integral in terms of u
- Integrate with respect to u
- Substitute back to get result in terms of x
Example:
∫ 2x·cos(x²) dx
Let u = x², du = 2x dx
= ∫ cos(u) du
= sin(u) + C
= sin(x²) + C
For definite integrals, also change the limits:
- If u = g(x), new limits are u = g(a) and u = g(b)
Integration by Parts
Formula:
∫ u dv = uv - ∫ v du
Intuition: Reverse Product Rule
Integration by parts is the integration version of the product rule. It trades one integral for another (hopefully simpler) integral.
The Core Idea:
- Product rule: (uv)' = u'v + uv'
- Rearrange: uv' = (uv)' - u'v
- Integrate both sides: ∫u(dv/dx)dx = uv - ∫v(du/dx)dx
- Or simply: ∫u dv = uv - ∫v du
When to Use: When integrand is a product of two different "types" of functions (polynomial × exponential, polynomial × trig, etc.)
The Strategy: Split the integrand into two parts:
- u: The part that gets simpler when differentiated
- dv: The part you can easily integrate
Why LIATE? This priority list ensures u gets simpler when you differentiate:
- Logarithmic → derivative is algebraic (simpler!)
- Inverse trig → derivative is algebraic (simpler!)
- Algebraic → derivative reduces power (simpler!)
- Trigonometric → derivative stays trig (no simpler)
- Exponential → derivative stays exponential (no simpler)
The Trade-Off: You're converting ∫u dv into uv - ∫v du. The goal is making the new integral ∫v du easier than the original.
Mental Model: You're "sacrificing" one factor (u) by differentiating it (hopefully simplifying it) while integrating the other (dv), then dealing with the resulting integral.
Pro Tip: Sometimes you need to apply integration by parts multiple times, or even in a cycle that allows you to solve for the original integral algebraically!
Choosing u and dv (LIATE rule):
- Logarithmic
- Inverse trigonometric
- Algebraic
- Trigonometric
- Exponential
Choose u in this order of preference; dv is what remains.
Example:
∫ x·e^x dx
u = x, dv = e^x dx
du = dx, v = e^x
= x·e^x - ∫ e^x dx
= x·e^x - e^x + C
= e^x(x - 1) + C
Tabular Integration: Efficient for repeated integration by parts.
Trigonometric Integrals
Strategies for ∫ sin^m(x)cos^n(x) dx:
- If n is odd: Save one cos(x), convert the rest to sin(x) using cos²(x) = 1 - sin²(x), then substitute u = sin(x)
- If m is odd: Save one sin(x), convert the rest to cos(x) using sin²(x) = 1 - cos²(x), then substitute u = cos(x)
- If both are even: Use the power-reducing formulas
- sin²(x) = (1 - cos(2x))/2
- cos²(x) = (1 + cos(2x))/2
Powers of tan and sec:
For ∫ tan^m(x)sec^n(x) dx:
- Use tan²(x) = sec²(x) - 1 and the fact that sec²(x) is the derivative of tan(x)
Trigonometric Substitution
For integrals involving √(a² - x²), √(a² + x²), or √(x² - a²):
- √(a² - x²): Let x = a·sin(θ), dx = a·cos(θ)dθ; then √(a² - x²) = a·cos(θ)
- √(a² + x²): Let x = a·tan(θ), dx = a·sec²(θ)dθ; then √(a² + x²) = a·sec(θ)
- √(x² - a²): Let x = a·sec(θ), dx = a·sec(θ)tan(θ)dθ; then √(x² - a²) = a·tan(θ)
Example:
∫ √(1 - x²) dx
Let x = sin(θ), dx = cos(θ)dθ
= ∫ cos(θ)·cos(θ) dθ
= ∫ cos²(θ) dθ
= ∫ (1 + cos(2θ))/2 dθ
= θ/2 + sin(2θ)/4 + C
= arcsin(x)/2 + x·√(1-x²)/2 + C
Partial Fractions
For rational functions P(x)/Q(x) where degree(P) < degree(Q):
Steps:
- Factor the denominator Q(x)
- Decompose into partial fractions
- Solve for coefficients (equate coefficients or plug in values)
- Integrate each term
Forms:
- Linear factors: (x - a) → A/(x - a)
- Repeated linear: (x - a)^n → A₁/(x-a) + A₂/(x-a)² + ... + Aₙ/(x-a)^n
- Quadratic factors: (x² + bx + c) → (Ax + B)/(x² + bx + c)
- Repeated quadratic: Similar to repeated linear
Example:
∫ 1/(x² - 1) dx = ∫ 1/[(x-1)(x+1)] dx
1/(x² - 1) = A/(x-1) + B/(x+1)
1 = A(x+1) + B(x-1)
Solving: A = 1/2, B = -1/2
= (1/2)∫ 1/(x-1) dx - (1/2)∫ 1/(x+1) dx
= (1/2)ln|x-1| - (1/2)ln|x+1| + C
= (1/2)ln|(x-1)/(x+1)| + C
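A quick way to sanity-check an antiderivative is to differentiate it numerically and compare against the original integrand. A minimal sketch for the example above, using a central finite difference at the assumed point x = 2:

```python
# Check the partial-fraction result: d/dx[(1/2)ln|(x-1)/(x+1)|]
# should equal 1/(x^2 - 1). Verified with a central difference at x = 2.
import math

F = lambda x: 0.5 * math.log(abs((x - 1) / (x + 1)))  # the antiderivative
f = lambda x: 1 / (x**2 - 1)                          # the integrand

x, h = 2.0, 1e-6
numeric_derivative = (F(x + h) - F(x - h)) / (2 * h)
print(abs(numeric_derivative - f(x)) < 1e-8)  # True: F'(2) ≈ 1/3
```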
Improper Integrals
Type 1: Infinite interval
∫[a to ∞] f(x) dx = lim(t→∞) ∫[a to t] f(x) dx
Type 2: Discontinuous integrand
∫[a to b] f(x) dx = lim(t→b⁻) ∫[a to t] f(x) dx (if f is discontinuous at b)
Convergence: The improper integral converges if the limit exists and is finite; otherwise it diverges.
Comparison Test: If 0 ≤ f(x) ≤ g(x) for x ≥ a:
- If ∫ g(x) dx converges, then ∫ f(x) dx converges
- If ∫ f(x) dx diverges, then ∫ g(x) dx diverges
Applications of Integration
Area Between Curves
Vertical slicing (integrate with respect to x):
A = ∫[a to b] [f(x) - g(x)] dx
where f(x) ≥ g(x) on [a,b]
Horizontal slicing (integrate with respect to y):
A = ∫[c to d] [f(y) - g(y)] dy
Volume
Disk Method (revolving around x-axis):
V = π·∫[a to b] [f(x)]² dx
Washer Method (hollow solid):
V = π·∫[a to b] ([R(x)]² - [r(x)]²) dx
where R(x) is the outer radius and r(x) is the inner radius
Shell Method (cylindrical shells):
V = 2π·∫[a to b] x·f(x) dx
or
V = 2π·∫[c to d] y·g(y) dy
Cross-Sectional Method:
V = ∫[a to b] A(x) dx
where A(x) is the area of cross-section at x
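As a sanity check on the disk method, revolving the half-circle f(x) = √(r² - x²) around the x-axis should produce a sphere. This minimal sketch (with an assumed radius r = 2) compares the disk-method sum against the known sphere volume (4/3)πr³:

```python
# Disk method: V = pi * integral of [f(x)]^2 dx.
# For f(x) = sqrt(r^2 - x^2) on [-r, r], [f(x)]^2 = r^2 - x^2,
# and the result should be the sphere volume (4/3) * pi * r^3.
import math

r = 2.0
n = 100_000
dx = 2 * r / n
# Midpoint sum of the squared radius of each disk, times pi * dx.
V = math.pi * sum(r**2 - (-r + (i + 0.5) * dx)**2 for i in range(n)) * dx

print(abs(V - (4 / 3) * math.pi * r**3) < 1e-4)  # True
```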
Arc Length
For y = f(x) on [a,b]:
L = ∫[a to b] √(1 + [f'(x)]²) dx
For parametric curves x = f(t), y = g(t) on [α,β]:
L = ∫[α to β] √([dx/dt]² + [dy/dt]²) dt
For polar curves r = f(θ):
L = ∫[α to β] √(r² + [dr/dθ]²) dθ
Surface Area
Revolution around x-axis:
S = 2π·∫[a to b] f(x)·√(1 + [f'(x)]²) dx
Revolution around y-axis:
S = 2π·∫[a to b] x·√(1 + [f'(x)]²) dx
Work
Constant force: W = F·d
Variable force:
W = ∫[a to b] F(x) dx
Examples:
- Spring: W = ∫ kx dx = (1/2)kx² (Hooke's Law)
- Lifting liquid: W = ∫ ρ·g·A(y)·y dy
- Pumping: Account for the distance each layer must be moved
Center of Mass
For a thin plate (lamina) with density ρ(x,y):
Mass:
m = ∬_R ρ(x,y) dA
Moments:
M_x = ∬_R y·ρ(x,y) dA
M_y = ∬_R x·ρ(x,y) dA
Center of mass:
x̄ = M_y / m
ȳ = M_x / m
For uniform density (ρ = constant), the center of mass equals the centroid.
Sequences and Series
Intuition: The Mathematics of Infinity
The Fundamental Questions:
- Sequences: Where is this infinite list heading?
- Series: Can we add infinitely many numbers and get a finite answer?
These questions connect discrete (countable steps) with continuous (limits), and finite with infinite.
Sequences
A sequence is an ordered list: {a₁, a₂, a₃, ...} or {aₙ}
Intuition: A sequence is a pattern that continues forever. Convergence asks: "Does this pattern settle down to a specific value, or does it keep wandering?"
Examples:
- {1, 1/2, 1/3, 1/4, ...} → converges to 0 (gets arbitrarily close)
- {1, -1, 1, -1, ...} → diverges (oscillates forever)
- {1, 2, 3, 4, ...} → diverges (grows without bound)
Convergence: lim(n→∞) aₙ = L means the sequence converges to L.
Properties:
- Monotonic: Always increasing or always decreasing
- Bounded: |aₙ| ≤ M for all n
- Monotone Convergence Theorem: A bounded, monotonic sequence converges
Series
An infinite series is the sum of a sequence:
Σ[n=1 to ∞] aₙ = a₁ + a₂ + a₃ + ...
Partial sums: Sₙ = Σ[k=1 to n] aₖ
Convergence: The series converges to S if lim(n→∞) Sₙ = S.
Geometric Series
Σ[n=0 to ∞] ar^n = a + ar + ar² + ar³ + ...
Convergence:
- If |r| < 1, the series converges to a/(1-r)
- If |r| ≥ 1, the series diverges
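The a/(1-r) formula is easy to watch converge. A minimal sketch with assumed values a = 3, r = 0.5 (sixty terms is already far more than needed):

```python
# Geometric series: partial sums of a * r^n approach a / (1 - r) when |r| < 1.
a, r = 3.0, 0.5
partial = sum(a * r**n for n in range(60))  # first 60 terms
closed_form = a / (1 - r)                   # 6.0

print(abs(partial - closed_form) < 1e-12)  # True
```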
Tests for Convergence
nth-Term Test (Divergence Test):
- If lim(n→∞) aₙ ≠ 0, then Σaₙ diverges
- If lim(n→∞) aₙ = 0, the test is inconclusive
Integral Test: If f is continuous, positive, and decreasing for x ≥ 1:
- Σ[n=1 to ∞] aₙ and ∫[1 to ∞] f(x) dx both converge or both diverge
p-Series:
Σ[n=1 to ∞] 1/n^p
Converges if p > 1, diverges if p ≤ 1
Comparison Test: If 0 ≤ aₙ ≤ bₙ for all n:
- If Σbₙ converges, then Σaₙ converges
- If Σaₙ diverges, then Σbₙ diverges
Limit Comparison Test: If aₙ, bₙ > 0 and lim(n→∞) aₙ/bₙ = c > 0:
- Both series converge or both diverge
Ratio Test:
L = lim(n→∞) |aₙ₊₁ / aₙ|
- If L < 1, the series converges absolutely
- If L > 1 (or L = ∞), the series diverges
- If L = 1, the test is inconclusive
Root Test:
L = lim(n→∞) |aₙ|^(1/n)
- If L < 1, the series converges absolutely
- If L > 1 (or L = ∞), the series diverges
- If L = 1, the test is inconclusive
Alternating Series Test: For an alternating series Σ(-1)^n·bₙ where bₙ > 0:
- If bₙ is decreasing and lim(n→∞) bₙ = 0, the series converges
Absolute and Conditional Convergence
- Absolutely convergent: Σ|aₙ| converges
- Conditionally convergent: Σaₙ converges but Σ|aₙ| diverges
If a series converges absolutely, it converges.
Power Series
A power series centered at a:
Σ[n=0 to ∞] cₙ(x - a)^n
Radius of Convergence (R):
- Series converges for |x - a| < R
- Series diverges for |x - a| > R
- At the endpoints x = a ± R, convergence must be tested separately
Finding R:
R = lim(n→∞) |cₙ / cₙ₊₁|
or
1/R = lim(n→∞) |cₙ₊₁ / cₙ|
Interval of Convergence: (a - R, a + R) plus possibly the endpoints
Taylor and Maclaurin Series
Taylor Series of f(x) centered at x = a:
f(x) = Σ[n=0 to ∞] [f⁽ⁿ⁾(a) / n!]·(x - a)^n
= f(a) + f'(a)(x-a) + [f''(a)/2!](x-a)² + [f'''(a)/3!](x-a)³ + ...
Maclaurin Series (special case where a = 0):
f(x) = Σ[n=0 to ∞] [f⁽ⁿ⁾(0) / n!]·x^n
Common Maclaurin Series:
- e^x = Σ[n=0 to ∞] x^n/n! = 1 + x + x²/2! + x³/3! + ...
- sin(x) = Σ[n=0 to ∞] (-1)^n·x^(2n+1)/(2n+1)! = x - x³/3! + x⁵/5! - ...
- cos(x) = Σ[n=0 to ∞] (-1)^n·x^(2n)/(2n)! = 1 - x²/2! + x⁴/4! - ...
- 1/(1-x) = Σ[n=0 to ∞] x^n = 1 + x + x² + x³ + ... (|x| < 1)
- ln(1+x) = Σ[n=1 to ∞] (-1)^(n+1)·x^n/n = x - x²/2 + x³/3 - ... (|x| < 1)
- arctan(x) = Σ[n=0 to ∞] (-1)^n·x^(2n+1)/(2n+1) = x - x³/3 + x⁵/5 - ... (|x| ≤ 1)
Taylor's Remainder:
Rₙ(x) = f(x) - Tₙ(x) = [f⁽ⁿ⁺¹⁾(c) / (n+1)!]·(x - a)^(n+1)
where c is between a and x.
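Truncated Maclaurin series are how calculators and math libraries approximate transcendental functions. A minimal sketch comparing a 20-term partial sum of the e^x series against `math.exp` at an assumed point x = 1.5:

```python
# Maclaurin series of e^x truncated at `terms` terms, vs math.exp.
import math

def exp_maclaurin(x, terms=20):
    # Sum of x^n / n! for n = 0 .. terms-1
    return sum(x**n / math.factorial(n) for n in range(terms))

x = 1.5
print(abs(exp_maclaurin(x) - math.exp(x)) < 1e-12)  # True
```

With 20 terms the remainder term x²⁰/20! is already far below double precision for moderate x, which is why the match is so tight.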
Multivariable Calculus
Intuition: Calculus in Higher Dimensions
The Big Picture: Everything we learned for single-variable calculus extends to functions of multiple variables. But now we have richer geometry and more directions to consider.
Key Difference: With one variable, there's only one direction—left or right. With multiple variables, there are infinitely many directions. How does the function change in each direction?
New Challenges:
- Rate of change depends on direction
- Surfaces instead of curves
- Volumes instead of areas
Core Concepts:
- Partial derivatives: Rate of change along coordinate axes
- Gradient: The vector pointing toward steepest increase
- Directional derivatives: Rate of change in any direction
- Multiple integrals: Volume under surfaces, mass of 3D objects
Partial Derivatives
For a function f(x,y):
Intuition: How does f change if I wiggle just ONE input variable, holding all others constant?
Mental Model: Imagine a mountain surface f(x,y) = height. Partial derivative ∂f/∂x is the slope if you walk in the pure x-direction (east-west). Partial derivative ∂f/∂y is the slope if you walk in the pure y-direction (north-south).
Why "Partial"? You're only looking at part of the story—change in one direction, ignoring others.
Practical Meaning:
- ∂Cost/∂Labor: How does cost change with more workers (holding materials constant)?
- ∂Temperature/∂x: How does temp change moving east (holding north-south position constant)?
Partial derivative with respect to x:
∂f/∂x = lim(h→0) [f(x+h, y) - f(x, y)] / h
Notation:
- ∂f/∂x, f_x, ∂_x f
Computing: Treat other variables as constants and differentiate normally.
Example: f(x,y) = x²y + y³
- ∂f/∂x = 2xy
- ∂f/∂y = x² + 3y²
Higher-order partial derivatives:
- f_xx = ∂²f/∂x²
- f_yy = ∂²f/∂y²
- f_xy = ∂²f/∂y∂x (mixed partial)
- f_yx = ∂²f/∂x∂y (mixed partial)
Clairaut's Theorem: If f_xy and f_yx are continuous, then f_xy = f_yx.
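The limit definition above translates directly into a finite-difference check. This sketch verifies the partials of the example f(x,y) = x²y + y³ at an assumed point (2, 3), using central differences:

```python
# Numeric check of the partials of f(x, y) = x^2*y + y^3:
# expect df/dx = 2xy and df/dy = x^2 + 3y^2.
f = lambda x, y: x**2 * y + y**3

def partial_x(f, x, y, h=1e-6):
    # Central difference in x, holding y fixed.
    return (f(x + h, y) - f(x - h, y)) / (2 * h)

def partial_y(f, x, y, h=1e-6):
    # Central difference in y, holding x fixed.
    return (f(x, y + h) - f(x, y - h)) / (2 * h)

x, y = 2.0, 3.0
print(abs(partial_x(f, x, y) - 2 * x * y) < 1e-6)          # True (2xy = 12)
print(abs(partial_y(f, x, y) - (x**2 + 3 * y**2)) < 1e-6)  # True (x² + 3y² = 31)
```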
Gradient
The gradient of f is a vector of partial derivatives:
∇f = <∂f/∂x, ∂f/∂y, ∂f/∂z> = f_x·i + f_y·j + f_z·k
Intuition: The Direction of Steepest Ascent
The gradient is the most important concept in multivariable calculus. It's a vector that answers: "Which way should I go to increase f the fastest?"
Mountain Analogy:
- Standing on a mountain, gradient points uphill in the steepest direction
- Magnitude of gradient = how steep that direction is
- Negative gradient points downhill (steepest descent)
- This is why gradient descent in machine learning works—it finds minimums!
Why a Vector? In multiple dimensions, "direction" needs multiple components. The gradient packs all directional information into one vector.
Properties:
- Points in the direction of maximum rate of increase. Why? It combines the rates along all coordinate directions into the single optimal direction.
- Perpendicular to level curves/surfaces. Why? Along a level curve, f doesn't change, so the direction of maximal change must be perpendicular to it.
- Magnitude is the maximum rate of change. Why? |∇f| is how much f increases per unit distance in the optimal direction.
Applications:
- Optimization: Follow gradient to find maxima
- Physics: Force = -∇(potential energy)
- Machine Learning: Gradient descent for training neural networks
- Computer Graphics: Surface normals for lighting
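Since the negative gradient points toward steepest descent, gradient descent (mentioned above) is just "repeatedly step against the gradient." A minimal sketch on an assumed toy function f(x,y) = (x-1)² + (y+2)², whose minimum is at (1, -2); the learning rate is an illustrative choice:

```python
# Gradient descent on f(x, y) = (x - 1)^2 + (y + 2)^2.
# Its gradient is (2(x - 1), 2(y + 2)); stepping against it
# should converge to the minimum at (1, -2).
x, y = 5.0, 5.0  # arbitrary starting point
lr = 0.1         # learning rate (step size, assumed)

for _ in range(200):
    gx, gy = 2 * (x - 1), 2 * (y + 2)  # gradient components
    x, y = x - lr * gx, y - lr * gy    # step in the steepest-descent direction

print(abs(x - 1) < 1e-6 and abs(y + 2) < 1e-6)  # True
```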
Directional Derivatives
The directional derivative of f at point P in direction of unit vector u:
D_u f = ∇f · u
The maximum rate of change occurs in the direction of ∇f, with magnitude |∇f|.
Chain Rule (Multivariable)
Case 1: z = f(x,y), x = g(t), y = h(t)
dz/dt = (∂z/∂x)·(dx/dt) + (∂z/∂y)·(dy/dt)
Case 2: z = f(x,y), x = g(s,t), y = h(s,t)
∂z/∂s = (∂z/∂x)·(∂x/∂s) + (∂z/∂y)·(∂y/∂s)
∂z/∂t = (∂z/∂x)·(∂x/∂t) + (∂z/∂y)·(∂y/∂t)
Extrema of Multivariable Functions
Critical points: Where ∇f = 0 or ∇f does not exist
Second Derivative Test: At a critical point (a,b):
D = f_xx(a,b)·f_yy(a,b) - [f_xy(a,b)]²
- If D > 0 and f_xx(a,b) > 0: local minimum
- If D > 0 and f_xx(a,b) < 0: local maximum
- If D < 0: saddle point
- If D = 0: test is inconclusive
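The four cases of the test can be captured in a tiny helper. The example values are illustrative: f(x,y) = x² - y² has f_xx = 2, f_yy = -2, f_xy = 0 at its critical point (0,0), so D = -4 and the point is a saddle:

```python
# Second derivative test: classify a critical point from f_xx, f_yy, f_xy.
def classify(f_xx, f_yy, f_xy):
    D = f_xx * f_yy - f_xy**2
    if D > 0:
        return "minimum" if f_xx > 0 else "maximum"
    return "saddle" if D < 0 else "inconclusive"

print(classify(2, -2, 0))  # saddle   (f = x^2 - y^2 at the origin)
print(classify(2, 2, 0))   # minimum  (f = x^2 + y^2 at the origin)
```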
Multiple Integrals
Double Integral over region R:
∬_R f(x,y) dA
Fubini's Theorem: If R = [a,b] × [c,d]:
∬_R f(x,y) dA = ∫[a to b] ∫[c to d] f(x,y) dy dx
= ∫[c to d] ∫[a to b] f(x,y) dx dy
Applications:
- Volume under surface: V = ∬_R f(x,y) dA
- Area of region: A = ∬_R 1 dA
- Mass: m = ∬_R ρ(x,y) dA
Triple Integral:
∭_E f(x,y,z) dV
Coordinate Systems
Polar Coordinates (x = r·cos(θ), y = r·sin(θ)):
∬_R f(x,y) dA = ∬ f(r·cos(θ), r·sin(θ))·r dr dθ
Cylindrical Coordinates (x = r·cos(θ), y = r·sin(θ), z = z):
∭_E f(x,y,z) dV = ∭ f(r·cos(θ), r·sin(θ), z)·r dz dr dθ
Spherical Coordinates (x = ρ·sin(φ)·cos(θ), y = ρ·sin(φ)·sin(θ), z = ρ·cos(φ)):
∭_E f(x,y,z) dV = ∭ f(ρ,θ,φ)·ρ²·sin(φ) dρ dθ dφ
Vector Calculus
Line Integrals:
∫_C f(x,y) ds = ∫[a to b] f(r(t))·|r'(t)| dt
∫_C F · dr = ∫[a to b] F(r(t)) · r'(t) dt
Green's Theorem (relates a line integral to a double integral):
∮_C P dx + Q dy = ∬_D (∂Q/∂x - ∂P/∂y) dA
Conservative Vector Fields:
- F = ∇f for some scalar function f (potential function)
- Line integral is path-independent
- ∮_C F · dr = 0 for any closed curve C
Test: F = <P, Q> is conservative if ∂P/∂y = ∂Q/∂x
Differential Equations
Intuition: Equations of Change
The Paradigm Shift: Normal equations tell you WHAT something is. Differential equations tell you HOW it CHANGES. The solution is a function, not a number.
The Core Idea: Many real-world phenomena are easier to describe in terms of rates of change rather than explicit formulas:
- Population grows proportionally to current population: dP/dt = kP
- Temperature approaches ambient temp: dT/dt = -k(T - T_ambient)
- Velocity changes due to forces: ma = F (Newton's 2nd law)
Why They're Powerful: Most natural laws are differential equations. Newton's laws, Maxwell's equations, Schrödinger equation—all DEs. Nature speaks the language of rates of change.
The Challenge: Given a rule for how something changes, find what it actually IS. This is harder than it sounds—you're essentially "integrating" but with more complex relationships.
Types of Solutions:
- General solution: Contains arbitrary constants (family of functions)
- Particular solution: Specific function satisfying initial conditions
- Explicit vs Implicit: Sometimes we can't solve for y explicitly
Mental Model: Imagine a vector field showing velocities at each point. A solution curve follows those velocity vectors. The differential equation defines the field; you find the curves.
Real-World Applications:
- Physics: Motion, heat, waves, quantum mechanics
- Biology: Population dynamics, disease spread, neural activity
- Economics: Growth models, market dynamics
- Engineering: Control systems, circuits, fluid flow
First-Order ODEs
General form: dy/dx = f(x,y) or M(x,y)dx + N(x,y)dy = 0
Intuition: First-order means only first derivatives (rate of change), no acceleration or higher rates. These are the simplest DEs and model basic change processes.
Separable Equations
Form: dy/dx = g(x)·h(y)
Method:
- Separate variables: [1/h(y)]dy = g(x)dx
- Integrate both sides
- Solve for y if possible
Example: dy/dx = xy
dy/y = x dx
ln|y| = x²/2 + C
y = Ae^(x²/2)
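When a DE can't be solved by hand, numeric stepping works from the same slope information. A minimal Euler's-method sketch for the example dy/dx = xy with an assumed initial condition y(0) = 1, compared against the exact solution y = e^(x²/2):

```python
# Euler's method for dy/dx = x*y with y(0) = 1.
# Exact solution (separable, A = 1): y = e^(x^2 / 2).
import math

def euler(f, x0, y0, x_end, n):
    h = (x_end - x0) / n
    x, y = x0, y0
    for _ in range(n):
        y += h * f(x, y)  # step along the slope field
        x += h
    return y

approx = euler(lambda x, y: x * y, 0.0, 1.0, 1.0, 100_000)
exact = math.exp(0.5)              # e^(1/2) ≈ 1.6487
print(abs(approx - exact) < 1e-3)  # True
```

Euler's method has first-order accuracy, so halving the step size roughly halves the error; production solvers use higher-order schemes like Runge-Kutta.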
Linear First-Order ODEs
Standard form: dy/dx + P(x)·y = Q(x)
Method (Integrating Factor):
- Compute μ(x) = e^(∫P(x)dx)
- Multiply the equation by μ(x)
- The left side becomes d/dx[μ(x)·y]
- Integrate: μ(x)·y = ∫μ(x)·Q(x)dx
- Solve for y
Example: dy/dx + y = e^x
μ(x) = e^(∫1 dx) = e^x
e^x·dy/dx + e^x·y = e^(2x)
d/dx[e^x·y] = e^(2x)
e^x·y = (1/2)e^(2x) + C
y = (1/2)e^x + Ce^(-x)
Exact Equations
Form: M(x,y)dx + N(x,y)dy = 0 is exact if ∂M/∂y = ∂N/∂x
Solution: Find a function f(x,y) such that:
- ∂f/∂x = M
- ∂f/∂y = N
Then f(x,y) = C is the solution.
Second-Order Linear ODEs
Homogeneous: ay'' + by' + cy = 0
Characteristic equation: ar² + br + c = 0
Solutions:
- Two distinct real roots r₁, r₂: y = C₁e^(r₁x) + C₂e^(r₂x)
- Repeated root r: y = (C₁ + C₂x)e^(rx)
- Complex roots r = α ± βi: y = e^(αx)[C₁cos(βx) + C₂sin(βx)]
Non-homogeneous: ay'' + by' + cy = g(x)
General solution: y = y_h + y_p
- y_h: homogeneous solution
- y_p: particular solution (use method of undetermined coefficients or variation of parameters)
Applications
Population growth: dP/dt = kP (exponential growth)
Newton's law of cooling: dT/dt = -k(T - T_ambient)
Spring-mass system: my'' + cy' + ky = F(t)
- m: mass
- c: damping coefficient
- k: spring constant
- F(t): external force
RC circuits: RC·dV/dt + V = V_source
Summary
This document covers the fundamental concepts of calculus:
- Limits and Continuity: Foundation for understanding change
- Derivatives: Instantaneous rates of change and tangent slopes
- Differentiation Techniques: Tools for computing derivatives
- Integration: Accumulation and area under curves
- Integration Techniques: Methods for evaluating integrals
- Applications: Real-world uses of calculus
- Sequences and Series: Infinite processes and approximations
- Multivariable Calculus: Extension to higher dimensions
- Differential Equations: Modeling change and dynamics
These concepts form the backbone of mathematical analysis and are essential tools in physics, engineering, economics, and many other fields.
Statistics: Understanding Data and Uncertainty
A comprehensive guide to statistical concepts with intuitive explanations and real-world applications.
Table of Contents
- Introduction
- Descriptive Statistics
- Percentiles and Quantiles
- Variance and Standard Deviation
- Probability Distributions
- Probability Basics
- Statistical Inference
- Correlation and Regression
- Real-World Applications
Introduction
Intuition: Making Sense of Uncertainty
The Core Question: How do we make decisions and draw conclusions when we don't have complete information?
What Statistics Does:
- Summarizes complex data into understandable numbers
- Quantifies uncertainty and variability
- Enables predictions from partial information
- Detects patterns in noisy data
- Tests whether observations are meaningful or just random
Why It Matters:
- Science: Testing hypotheses, validating experiments
- Engineering: Performance monitoring, reliability analysis
- Business: A/B testing, customer behavior analysis
- Medicine: Clinical trials, epidemiology
- Everyday Life: Weather forecasts, election polls, sports analytics
The Fundamental Insight: We can never know everything, but statistics lets us quantify what we know, what we don't know, and how confident we should be.
Descriptive Statistics
Intuition: Summarizing Data
When you have thousands or millions of data points, you need to condense them into a few meaningful numbers. Descriptive statistics are those summaries.
Measures of Central Tendency
The Question: What's a "typical" value?
Mean (Average)
Mean = (Sum of all values) / (Number of values)
μ = (x₁ + x₂ + ... + xₙ) / n
Intuition: The "balance point" of your data. If you put all values on a number line, the mean is where it would balance.
Strengths:
- Uses all data points
- Mathematically convenient
- Minimizes squared errors
Weaknesses:
- Sensitive to outliers (one billionaire raises average income dramatically)
- Can be misleading for skewed data
When to Use: Symmetric data without extreme outliers
Example: Average response time = 50ms
- Means: sum of all response times divided by number of requests
Median (Middle Value)
Intuition: Line up all values from smallest to largest. The median is the middle one. Half the values are below it, half above.
Calculation:
- Odd number of values: middle value
- Even number of values: average of two middle values
Strengths:
- Robust to outliers
- Better for skewed data
- Actually achievable value (or close to it)
Weaknesses:
- Ignores magnitude of extreme values
- Less mathematically convenient
When to Use: Skewed data or data with outliers (like income, house prices, response times)
Example: Median house price = $350,000
- Means: half of houses cost more, half cost less
- Not affected if the most expensive house costs $10M or $100M
Mode (Most Common)
Intuition: The value that appears most often. The "crowd favorite."
Strengths:
- Easy to understand
- Works for categorical data (most common color: blue)
- Identifies peaks in distribution
Weaknesses:
- May not exist or may not be unique
- Ignores most of the data
When to Use: Categorical data or finding the most typical value
Example: Most common shoe size = 9
- More people wear size 9 than any other size
Mean vs Median: When They Differ
Key Insight: Mean = Median only for symmetric distributions.
Skewed Right (long tail to right):
- Mean > Median
- Example: Income (few billionaires pull mean up)
Skewed Left (long tail to left):
- Mean < Median
- Example: Age at death (few infant deaths pull mean down)
Real-World Impact:
- "Average income" can be misleading
- In web performance, median latency often more meaningful than mean
- Politicians prefer whichever metric makes their argument stronger!
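The skew effect is easy to demonstrate. A minimal sketch with a small invented income sample: one extreme outlier drags the mean up while the median barely moves:

```python
# Mean vs median on a symmetric sample, then with one extreme outlier.
import statistics

incomes = [30_000, 35_000, 40_000, 45_000, 50_000]
print(statistics.mean(incomes) == statistics.median(incomes))  # True: symmetric

incomes.append(10_000_000)  # one hypothetical very high earner
print(statistics.median(incomes))  # 42500.0: barely moved
# Mean jumps to 1,700,000: "average income" now describes nobody.
print(statistics.mean(incomes) > statistics.median(incomes))  # True: skewed right
```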
Percentiles and Quantiles
Intuition: Understanding the Full Picture
The Problem with Averages: The average doesn't tell you about the worst-case experience.
The Core Idea: Percentiles divide your data into 100 equal parts. The Pth percentile is the value below which P% of the data falls.
What Percentiles Mean
p50 (Median): 50% of values are below this
- The "typical" experience
- Half your users experience better, half worse
p90 (90th Percentile): 90% of values are below this
- 1 in 10 users experience worse than this
- Shows you're capturing most users
p95 (95th Percentile): 95% of values are below this
- 1 in 20 users experience worse
- Common SLA target
p99 (99th Percentile): 99% of values are below this
- 1 in 100 users experience worse
- Critical for high-traffic systems
p99.9 (99.9th Percentile): 99.9% of values are below this
- 1 in 1000 users experience worse
- Catches rare but severe issues
Why Percentiles Matter in Software Engineering
The Tail Latency Problem:
Imagine you run a web service:
- Mean latency: 10ms
- Sounds great, right?
But:
- p50: 5ms (half of requests are super fast)
- p90: 20ms (still reasonable)
- p99: 500ms (1% of requests are horribly slow!)
- p99.9: 5000ms (worst experiences are terrible)
The Reality:
- Mean doesn't show you the worst-case experience
- Users remember bad experiences
- High-percentile latencies indicate problems
Real-World Scenario:
You have 1 million requests/day:
- 1% (p99) = 10,000 requests
- 0.1% (p99.9) = 1,000 requests
Even "rare" problems affect thousands of users!
Percentiles in SLAs (Service Level Agreements)
Common SLA Format:
- "99% of requests complete in < 100ms" (p99 < 100ms)
- "95% of requests complete in < 50ms" (p95 < 50ms)
Why Not p100?:
- Outliers always exist (network hiccups, GC pauses, cosmic rays!)
- One bad request shouldn't violate SLA
- p99 or p99.9 more realistic and actionable
The Trade-off:
- Higher percentiles (p99.9) = better user experience
- But harder and more expensive to optimize
- Diminishing returns: p99 → p99.9 much harder than p50 → p90
Calculating Percentiles
Method (simplified):
- Sort all values from smallest to largest
- Find position: P% × (number of values)
- Take the value at that position
Example: 100 response times, p95:
- Position: 95% × 100 = 95
- Take the 95th value when sorted
In Practice:
- Use histogram approximations for efficiency
- Tools: Prometheus, Datadog, New Relic calculate automatically
- Streaming algorithms for real-time monitoring
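The simplified sort-and-index method above can be sketched in a few lines; the `percentile` helper and the synthetic latency data are illustrative, and real systems use the streaming approximations just mentioned:

```python
# Percentile by sorting: take the value at position ceil(p/100 * n)
# (1-based) in the sorted data, per the simplified method above.
import math

def percentile(values, p):
    ordered = sorted(values)
    k = math.ceil(p / 100 * len(ordered))  # 1-based position
    return ordered[max(k - 1, 0)]          # convert to 0-based index

latencies_ms = list(range(1, 101))  # 1, 2, ..., 100 (synthetic data)
print(percentile(latencies_ms, 50))  # 50
print(percentile(latencies_ms, 95))  # 95
print(percentile(latencies_ms, 99))  # 99
```

Note that several interpolation conventions exist (Python's `statistics.quantiles`, NumPy, and monitoring tools may each differ slightly at the boundaries); the differences shrink as the dataset grows.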
Percentiles vs Averages: A Critical Comparison
| Metric | Tells You | Hides | Best For |
|---|---|---|---|
| Mean | Overall performance | Bad outliers | Resource planning |
| Median (p50) | Typical experience | Half of users | Understanding norm |
| p90 | 90% of users | Worst 10% | General SLA |
| p95 | 95% of users | Worst 5% | Tighter SLA |
| p99 | 99% of users | Worst 1% | High-scale services |
| p99.9 | 99.9% of users | Worst 0.1% | Critical systems |
The Rule: Monitor multiple percentiles to understand your full distribution.
Intuitive Examples
Restaurant Wait Times:
- p50 = 15 min: Half wait less
- p90 = 30 min: 90% wait less than half an hour
- p99 = 60 min: 1 in 100 wait over an hour
- Mean = 20 min: (can be misleading if a few people wait 2 hours)
API Response Times:
- p50 = 20ms: Typical request
- p95 = 100ms: SLA target
- p99 = 500ms: Degraded but acceptable
- p99.9 = 5000ms: Something's seriously wrong
Key Insight: If your p99 is 10x your p50, you have a tail latency problem!
Variance and Standard Deviation
Intuition: Measuring Spread
The Question: How "spread out" are the values? How much do they differ from the average?
Variance
Formula:
Variance (σ²) = Average of squared differences from mean
σ² = Σ(xᵢ - μ)² / n
Intuition:
- Find how far each value is from the mean
- Square those differences (so positive and negative don't cancel)
- Average the squared differences
Why Square?:
- Makes all differences positive
- Penalizes large deviations more (100² = 10,000 vs 10² = 100)
- Mathematically convenient
Units: Squared units (if data is in ms, variance is in ms²)
Standard Deviation
Formula:
Standard Deviation (σ) = √Variance
σ = √[Σ(xᵢ - μ)² / n]
Intuition: The "typical" distance from the mean. It's variance brought back to original units.
Why Take Square Root?:
- Returns to original units (ms, not ms²)
- More interpretable
- Roughly the "average deviation"
The 68-95-99.7 Rule (for normal distributions):
- 68% of values within 1σ of mean
- 95% of values within 2σ of mean
- 99.7% of values within 3σ of mean
Example:
Test scores:
- Mean = 75
- Standard deviation = 10
Interpretation:
- Most students score within 10 points of 75
- 68% score between 65-85
- 95% score between 55-95
- 99.7% score between 45-105
- Anyone scoring below 45 or above 105 is very unusual
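In code, variance and standard deviation fall straight out of the formulas above (the scores here are made up so the mean lands exactly on 75):

```python
import math

scores = [65, 70, 75, 75, 80, 85]   # hypothetical test scores
mean = sum(scores) / len(scores)    # mu = 75
# Average of squared differences from the mean (population variance)
variance = sum((x - mean) ** 2 for x in scores) / len(scores)
std_dev = math.sqrt(variance)       # back to original units (points)
```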
Low vs High Variance
Low Variance/StdDev:
- Values cluster tightly around mean
- Predictable, consistent
- Example: Manufacturing tolerances
High Variance/StdDev:
- Values spread widely
- Unpredictable, inconsistent
- Example: Stock prices, startup outcomes
Real-World Application:
API latency:
- Service A: mean=50ms, σ=5ms (very consistent)
- Service B: mean=50ms, σ=100ms (wildly unpredictable)
Both have same mean, but Service B is much worse for users!
Probability Distributions
Intuition: Patterns in Randomness
The Core Idea: Random doesn't mean "anything can happen." It means outcomes follow predictable patterns.
Normal Distribution (Gaussian)
The Bell Curve
Characteristics:
- Symmetric, bell-shaped
- Mean = Median = Mode
- Defined by mean (μ) and standard deviation (σ)
Why It's Everywhere:
- Central Limit Theorem: Average of many independent random variables → normal
- Natural processes often combine many small random effects
- Height, measurement errors, test scores
Properties:
- 68% within 1σ
- 95% within 2σ
- 99.7% within 3σ
Real-World Examples:
- Human height
- Measurement errors
- IQ scores
- Blood pressure
When It Fails:
- Income (heavy right tail)
- Web latency (long right tail)
- Rare events (need exponential or power law)
Exponential Distribution
For Waiting Times
Characteristics:
- Models time between events
- Always positive
- Heavy right tail
- Memoryless property
Formula:
P(X > t) = e^(-λt)
Intuitive Meaning: "How long until the next event?"
Real-World Examples:
- Time between server requests
- Time until hardware failure
- Radioactive decay
- Customer arrivals
Memoryless Property: Past doesn't affect future
- If component hasn't failed for 5 years, probability of failure next year is same as year 1
- "The universe doesn't remember"
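The memoryless property can be verified directly from P(X > t) = e^(-λt); the failure rate below is a made-up figure:

```python
import math

lam = 0.2  # hypothetical failure rate: 0.2 failures per year

def survival(t):
    """P(X > t): probability the component is still alive after t years."""
    return math.exp(-lam * t)

# P(survives to year 6 | survived 5 years) vs. a brand-new component's
# chance of surviving its first year -- the exponential says they're equal
p_given_5_years = survival(6) / survival(5)
p_first_year = survival(1)
```

The division cancels the e^(-5λ) factor, which is the algebra behind "the universe doesn't remember."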
Poisson Distribution
For Counting Rare Events
Characteristics:
- Counts events in fixed interval
- Events occur independently
- Average rate known
Formula:
P(k events) = (λ^k × e^(-λ)) / k!
Real-World Examples:
- Number of requests per second
- Number of bugs in code
- Number of emails per hour
- Rare disease cases
Example:
Server gets average 5 requests/second (λ=5)
- What's probability of exactly 3 requests in next second?
- What's probability of 0 requests (downtime)?
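Both questions can be answered by plugging into the Poisson formula (standard library only):

```python
import math

def poisson_pmf(k, lam):
    """P(exactly k events) given an average rate of lam per interval."""
    return lam ** k * math.exp(-lam) / math.factorial(k)

lam = 5                          # average 5 requests/second
p_three = poisson_pmf(3, lam)    # exactly 3 requests in the next second
p_zero = poisson_pmf(0, lam)     # a second with zero requests
```

p_three works out to about 0.14, and p_zero to e^(-5), under 1% -- a fully idle second is rare at this rate.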
Long-Tail Distributions
The 80-20 Rule (Pareto Principle)
Characteristics:
- Most values small
- Few values VERY large
- Mean >> Median
- Standard deviation huge
Real-World Examples:
- Wealth distribution (1% owns most wealth)
- Web traffic (few pages get most visits)
- API latency (most fast, few horribly slow)
- City sizes (few mega-cities, many small towns)
Why It Matters:
- Mean is misleading
- Must use percentiles
- Outliers dominate
The Tail Latency Problem Revisited:
- Most requests fast
- But 1% can be 100x slower
- Those slow requests kill user experience
Probability Basics
Intuition: Quantifying Uncertainty
Probability = How likely something is to happen, on a scale from 0 (impossible) to 1 (certain)
Fundamental Rules
Addition Rule (OR):
P(A or B) = P(A) + P(B) - P(A and B)
Intuition: Add probabilities, but don't double-count overlap
Example: Drawing a heart OR a king
- P(heart) = 13/52
- P(king) = 4/52
- P(king of hearts) = 1/52
- P(heart or king) = 13/52 + 4/52 - 1/52 = 16/52
Multiplication Rule (AND - Independent):
P(A and B) = P(A) × P(B) [if independent]
Intuition: Multiply when events don't affect each other
Example: Flipping heads twice
- P(first heads) = 1/2
- P(second heads) = 1/2
- P(both heads) = 1/2 × 1/2 = 1/4
Conditional Probability
The Question: How does knowing one thing change probability of another?
Formula:
P(A|B) = P(A and B) / P(B)
Read as: "Probability of A given B"
Intuition: Restrict your universe to only cases where B happened
Example:
Drawing cards:
- P(king) = 4/52
- P(king | heart) = 1/13
Why? If you know it's a heart, you're only considering 13 cards, and 1 is a king.
Bayes' Theorem
The Ultimate Reasoning Tool
Formula:
P(A|B) = P(B|A) × P(A) / P(B)
Intuition: Update your beliefs based on evidence
Components:
- P(A): Prior (what you believed before)
- P(B|A): Likelihood (how well evidence fits hypothesis)
- P(A|B): Posterior (updated belief)
Real-World Example: Medical Testing
Disease affects 1% of population:
- P(disease) = 0.01
- Test is 95% accurate
- You test positive
What's P(disease | positive test)?
Naive Answer: 95% (wrong!)
Bayesian Answer:
- True positives: 1% have disease × 95% test positive = 0.95%
- False positives: 99% healthy × 5% false positive = 4.95%
- Total positives: 0.95% + 4.95% = 5.9%
- P(disease | positive) = 0.95% / 5.9% ≈ 16%
Shocking Result: Even with positive test, only 16% chance of having disease!
Why?: Rare diseases mean false positives outnumber true positives.
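The medical-testing arithmetic above translates directly into code:

```python
p_disease = 0.01               # prior: 1% of population has the disease
p_pos_if_sick = 0.95           # test sensitivity
p_pos_if_healthy = 0.05        # false positive rate (95% accurate test)

true_positives = p_disease * p_pos_if_sick            # 0.95% of everyone
false_positives = (1 - p_disease) * p_pos_if_healthy  # 4.95% of everyone
posterior = true_positives / (true_positives + false_positives)
# posterior is about 0.16 -- only ~16% chance of disease despite the positive test
```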
Statistical Inference
Intuition: From Sample to Population
The Problem: You can't measure everyone. How do you draw conclusions about a population from a sample?
Confidence Intervals
The Question: What range of values is likely to contain the true population parameter?
Formula (for mean, large sample):
CI = sample mean ± (z-score × standard error)
CI = x̄ ± z × (σ/√n)
Interpretation:
"95% confidence interval: [45, 55]"
Correct: If we repeated this experiment many times, 95% of our intervals would contain the true mean.
Wrong (common misconception): 95% chance the true mean is in [45, 55]
Intuitive Analogy: Fishing with a net
- Each sample = one cast
- 95% confidence = your net catches the fish 95% of the time
- The fish (true mean) doesn't move; your net (interval) does
Key Insight: Larger sample → narrower interval → more precise estimate
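A sketch of the large-sample formula, using an invented sample whose mean is exactly 50 (a real analysis with n this small would use a t-distribution rather than z = 1.96):

```python
import math

sample = [48, 52, 50, 47, 53, 49, 51, 50, 46, 54]  # hypothetical measurements
n = len(sample)
x_bar = sum(sample) / n
# Sample standard deviation (n-1 in the denominator)
s = math.sqrt(sum((x - x_bar) ** 2 for x in sample) / (n - 1))
z = 1.96                           # z-score for 95% confidence
margin = z * s / math.sqrt(n)      # z times the standard error
ci = (x_bar - margin, x_bar + margin)
```

Doubling n shrinks the margin by a factor of sqrt(2), which is the "larger sample, narrower interval" insight in formula form.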
Hypothesis Testing
The Question: Is what I'm seeing real, or just random chance?
The Null Hypothesis (H₀): The boring explanation
- "No difference"
- "No effect"
- "Just randomness"
Alternative Hypothesis (H₁): The interesting claim
- "There IS a difference"
- "Treatment works"
- "Something happened"
Process:
- Assume null hypothesis is true
- Calculate: How likely is the data we saw?
- If very unlikely, reject null hypothesis
p-values
Definition: Probability of seeing data this extreme (or more) if null hypothesis were true
Interpretation:
p-value = 0.03 (3%)
Correct: If there's truly no effect, you'd see results this extreme only 3% of the time.
Wrong: 97% chance hypothesis is true.
Common Threshold: p < 0.05 = "statistically significant"
- Arbitrary but conventional
- Means: if there were truly no effect, data this extreme would appear less than 5% of the time
The Problem with p-values:
- p=0.049: "Significant!" (publish!)
- p=0.051: "Not significant" (file away)
- Tiny difference, huge consequence
Better Approach: Report confidence intervals AND p-values
Type I and Type II Errors
Type I Error (False Positive):
- Reject null hypothesis when it's actually true
- "Crying wolf"
- Example: Approve ineffective drug
Type II Error (False Negative):
- Fail to reject null hypothesis when it's false
- "Missing the wolf"
- Example: Reject effective drug
The Trade-off: Reducing one increases the other
Real-World Impact:
- Criminal justice: Convict innocent vs. free guilty
- Medicine: Approve bad drug vs. reject good drug
- Spam filter: Block good email vs. allow spam
Correlation and Regression
Correlation
The Question: Do two variables tend to move together?
Correlation Coefficient (r):
- Range: -1 to +1
- r = +1: Perfect positive correlation
- r = -1: Perfect negative correlation
- r = 0: No linear correlation
Intuition:
- r = +0.9: Strong positive (when X goes up, Y usually goes up)
- r = -0.9: Strong negative (when X goes up, Y usually goes down)
- r = 0.1: Weak/no relationship
Real Examples:
- Height and weight: r ≈ 0.7 (positive, not perfect)
- Temperature and heating costs: r ≈ -0.8 (negative)
- Shoe size and IQ: r ≈ 0 (no correlation)
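Pearson's r can be computed from scratch; the height/weight pairs below are invented to show a strong positive correlation:

```python
import math

def pearson_r(xs, ys):
    """Correlation coefficient: covariance scaled into [-1, +1]."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy)

heights = [160, 165, 170, 175, 180]   # cm (hypothetical)
weights = [55, 62, 66, 74, 79]        # kg (hypothetical)
r = pearson_r(heights, weights)       # strong positive, close to +1
```

A variable correlated with itself gives r = 1 exactly, a useful sanity check on any implementation.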
Correlation ≠ Causation
The Most Important Statistical Lesson
Just because two things correlate doesn't mean one causes the other!
Classic Examples:
- Ice cream sales and drowning deaths (positive correlation)
  - Cause? Both increase in summer!
  - Ice cream doesn't cause drowning
- Nicolas Cage movies and swimming pool drownings
  - Pure coincidence
  - Spurious correlation
- Shoe size and reading ability (in children)
  - Correlated, but age causes both
  - Confounding variable
Possible Explanations for Correlation:
- A causes B
- B causes A
- C causes both A and B
- Pure coincidence
- Complex interconnection
How to Establish Causation:
- Randomized controlled trials
- Natural experiments
- Careful reasoning and domain knowledge
Linear Regression
The Question: Can we predict Y from X?
Formula:
y = mx + b
Intuition: Find the best straight line through the data
What "Best" Means: Minimize squared vertical distances (least squares)
Gives You:
- Slope (m): How much Y changes per unit of X
- Intercept (b): Value of Y when X=0
Example:
Advertising spend (X) vs Sales (Y):
- Slope = 2.5
- Interpretation: Each $1 in ads → $2.50 in sales (approximately)
Limitations:
- Assumes linear relationship
- Correlation ≠ causation still applies!
- Extrapolation dangerous
- Outliers heavily influence line
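A minimal least-squares fit from scratch (the ad-spend data is fabricated and perfectly linear, so the recovered slope is exactly the 2.5 from the example):

```python
def fit_line(xs, ys):
    """Ordinary least squares for y = m*x + b."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    slope = (sum((x - mx) * (y - my) for x, y in zip(xs, ys))
             / sum((x - mx) ** 2 for x in xs))
    intercept = my - slope * mx
    return slope, intercept

ad_spend = [1, 2, 3, 4, 5]            # $k spent on ads (hypothetical)
sales = [3.0, 5.5, 8.0, 10.5, 13.0]   # $k in sales (hypothetical)
m, b = fit_line(ad_spend, sales)      # slope 2.5, intercept 0.5
```

Real data is never this clean; with noise the same formula still minimizes the squared vertical distances, but a single extreme outlier can drag the whole line, as the limitations list warns.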
Real-World Applications
Performance Monitoring (SRE/DevOps)
Why Percentiles Over Averages:
Scenario: API serving 1M requests/day
Mean latency = 50ms:
- Looks great!
- But hides problems
Percentile breakdown:
- p50: 20ms (half of users, fast)
- p90: 100ms (90% acceptable)
- p95: 500ms (5% degraded)
- p99: 5000ms (10,000 users/day suffering!)
- p99.9: timeout (1,000 users/day broken)
Action Items:
- p99 > 1s → investigate
- p99 increasing → system degrading
- p50 vs p99 ratio > 10 → tail latency problem
SLA Design:
- Good: "p95 < 100ms, p99 < 500ms"
- Bad: "average < 100ms" (hides outliers)
A/B Testing
Question: Does new feature improve metrics?
Process:
- Split users: 50% see old, 50% see new
- Measure outcome (clicks, purchases, retention)
- Test if difference is statistically significant
Common Pitfalls:
- p-hacking: Testing until you find p<0.05
- Multiple testing: 20 tests → 1 will be "significant" by chance
- Stopping early when winning
- Ignoring business significance vs statistical significance
Best Practices:
- Preregister hypothesis
- Calculate required sample size
- Use confidence intervals
- Consider practical significance
Reliability Engineering
Mean Time Between Failures (MTBF):
- Average time system runs before failing
- Higher = more reliable
Mean Time To Repair (MTTR):
- Average time to fix after failure
- Lower = faster recovery
Availability:
Availability = MTBF / (MTBF + MTTR)
Example:
- MTBF = 100 hours
- MTTR = 1 hour
- Availability = 100/101 ≈ 99%
Nines of Availability:
- 99% (two nines): 3.65 days downtime/year
- 99.9% (three nines): 8.77 hours/year
- 99.99% (four nines): 52.6 minutes/year
- 99.999% (five nines): 5.26 minutes/year
The Cost: Each additional nine is exponentially harder and more expensive to achieve
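The availability formula and the "nines" figures convert to code like this (using a 365-day year, so the numbers differ slightly from values computed with 365.25 days):

```python
mtbf = 100.0   # hours of operation between failures
mttr = 1.0     # hours to repair after a failure
availability = mtbf / (mtbf + mttr)    # 100/101, roughly two nines

HOURS_PER_YEAR = 24 * 365

def downtime_hours_per_year(avail):
    """Expected downtime per year at a given availability."""
    return (1 - avail) * HOURS_PER_YEAR

three_nines = downtime_hours_per_year(0.999)    # 8.76 hours/year
four_nines = downtime_hours_per_year(0.9999)    # 0.876 hours = 52.56 min/year
```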
Capacity Planning
Scenario: How many servers needed?
Using Statistics:
- Measure current load (requests/second)
- Find p99 latency
- Account for traffic growth
- Add headroom (multiply by 1.5-2x)
- Load test at that capacity
Example:
- Current: 1000 req/s, p99 = 100ms
- Expected growth: 2x
- Target: 2000 req/s, p99 < 100ms
- With headroom: provision for 3000-4000 req/s
Why Percentiles Matter:
- Provisioning for average → p99 users suffer
- Provision for p99 → acceptable worst-case
Summary
Key Statistical Concepts
Descriptive Statistics:
- Mean: Average, sensitive to outliers
- Median: Middle value, robust to outliers
- Mode: Most common value
Spread:
- Variance: Average squared deviation
- Standard Deviation: Typical distance from mean
- Percentiles: Values below which P% of data falls
Percentiles (Critical for Performance):
- p50 (Median): Typical experience
- p90: Captures 90% of users
- p95: Common SLA target
- p99: High-scale systems, catches rare problems
- p99.9: Critical systems
Distributions:
- Normal: Bell curve, symmetric
- Exponential: Waiting times
- Poisson: Counting rare events
- Long-tail: Few extreme values dominate
Inference:
- Confidence Intervals: Range for true value
- p-values: Probability of seeing data if null true
- Hypothesis Testing: Is effect real or random?
Correlation:
- Measures relationship (-1 to +1)
- Correlation ≠ Causation!
- Regression: Prediction from relationship
Key Lessons
- Mean hides outliers → Use percentiles
- p99 matters → 1% of users = thousands of people
- Correlation ≠ Causation → Always question
- p-values misunderstood → Report CI too
- Variance matters → Same mean, different experience
- Context critical → Numbers meaningless without it
- Long tails everywhere → Normal distribution rare in real world
Practical Wisdom
For System Monitoring:
- Track p50, p90, p95, p99
- Alert on p99 degradation
- Use percentiles in SLAs
For Decision Making:
- Larger sample → more confidence
- Statistical significance ≠ practical significance
- Always visualize data
- Question assumptions
For Communication:
- Use appropriate metric (mean vs median vs percentile)
- Show uncertainty (confidence intervals)
- Explain what statistics mean, not just values
Statistics is the science of learning from incomplete information. Master it, and you can make better decisions in an uncertain world.
Matplotlib: Complete Guide for Data Visualization
Matplotlib is the foundational plotting library for Python, providing publication-quality visualizations and serving as the basis for many other plotting libraries (Seaborn, Pandas plotting, etc.).
Table of Contents
- Architecture & Core Concepts
- Basic Plotting
- Figure and Axes Management
- Customization Deep Dive
- Advanced Plot Types
- Styling and Themes
- ML/Data Science Visualizations
- Working with Images
- Animations
- Integration Patterns
- Performance & Best Practices
- Common Patterns & Recipes
Architecture & Core Concepts
The Matplotlib Hierarchy
Matplotlib has a hierarchical structure that's essential to understand:
Figure (entire window)
└── Axes (plot area, NOT axis!)
├── Axis (x-axis, y-axis)
├── Spines (plot boundaries)
├── Artists (everything you see)
└── Legend, Title, Labels
import matplotlib.pyplot as plt
import numpy as np
# Understanding the hierarchy
fig = plt.figure(figsize=(10, 6)) # Figure: the whole window
ax = fig.add_subplot(111) # Axes: a plot area
# Everything drawn is an "Artist"
line, = ax.plot([1, 2, 3], [1, 4, 2]) # Line2D artist
text = ax.text(2, 3, 'Point') # Text artist
Two Interfaces: pyplot vs Object-Oriented
# PYPLOT INTERFACE (MATLAB-style, stateful)
plt.plot([1, 2, 3], [1, 4, 2])
plt.xlabel('X Label')
plt.ylabel('Y Label')
plt.title('Title')
plt.show()
# OBJECT-ORIENTED INTERFACE (Recommended for complex plots)
fig, ax = plt.subplots()
ax.plot([1, 2, 3], [1, 4, 2])
ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_title('Title')
plt.show()
When to use which:
- pyplot: Quick exploratory plots, simple scripts
- OO interface: Complex figures, multiple subplots, functions that create plots, production code
Key Design Principle
# Everything in matplotlib is customizable
# General pattern:
fig, ax = plt.subplots()
# Plot data
artist = ax.plot(x, y)
# Customize
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_title('Title')
# Display or save
plt.savefig('plot.png', dpi=300, bbox_inches='tight')
plt.show()
Basic Plotting
Line Plots
import numpy as np
import matplotlib.pyplot as plt
# Single line
x = np.linspace(0, 10, 100)
y = np.sin(x)
fig, ax = plt.subplots()
ax.plot(x, y)
ax.set_xlabel('X')
ax.set_ylabel('sin(X)')
ax.set_title('Sine Wave')
plt.show()
# Multiple lines
y1 = np.sin(x)
y2 = np.cos(x)
fig, ax = plt.subplots()
ax.plot(x, y1, label='sin(x)')
ax.plot(x, y2, label='cos(x)')
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()
# Customized line styles
fig, ax = plt.subplots()
ax.plot(x, y1, 'r-', linewidth=2, label='solid')
ax.plot(x, y2, 'b--', linewidth=2, label='dashed')
ax.plot(x, y1 + 0.5, 'g-.', linewidth=2, label='dash-dot')
ax.plot(x, y2 + 0.5, 'k:', linewidth=2, label='dotted')
ax.legend()
plt.show()
Scatter Plots
# Basic scatter
x = np.random.randn(100)
y = np.random.randn(100)
fig, ax = plt.subplots()
ax.scatter(x, y)
ax.set_xlabel('X')
ax.set_ylabel('Y')
plt.show()
# Customized scatter with size and color
sizes = np.random.rand(100) * 100
colors = np.random.rand(100)
fig, ax = plt.subplots()
scatter = ax.scatter(x, y, s=sizes, c=colors,
cmap='viridis', alpha=0.6,
edgecolors='black', linewidth=0.5)
plt.colorbar(scatter, ax=ax, label='Color Value')
ax.set_xlabel('X')
ax.set_ylabel('Y')
plt.show()
# Multiple scatter series
x1 = np.random.normal(0, 1, 100)
y1 = np.random.normal(0, 1, 100)
x2 = np.random.normal(3, 1, 100)
y2 = np.random.normal(3, 1, 100)
fig, ax = plt.subplots()
ax.scatter(x1, y1, label='Class 1', alpha=0.6)
ax.scatter(x2, y2, label='Class 2', alpha=0.6)
ax.legend()
ax.set_xlabel('Feature 1')
ax.set_ylabel('Feature 2')
plt.show()
Bar Charts
# Vertical bar chart
categories = ['A', 'B', 'C', 'D', 'E']
values = [25, 40, 30, 55, 45]
fig, ax = plt.subplots()
bars = ax.bar(categories, values, color='steelblue',
edgecolor='black', linewidth=1.2)
ax.set_ylabel('Values')
ax.set_title('Bar Chart')
# Add value labels on bars
for bar in bars:
height = bar.get_height()
ax.text(bar.get_x() + bar.get_width()/2., height,
f'{height}', ha='center', va='bottom')
plt.show()
# Horizontal bar chart
fig, ax = plt.subplots()
ax.barh(categories, values, color='coral')
ax.set_xlabel('Values')
plt.show()
# Grouped bar chart
x = np.arange(len(categories))
values1 = [25, 40, 30, 55, 45]
values2 = [30, 35, 45, 40, 50]
width = 0.35
fig, ax = plt.subplots()
bars1 = ax.bar(x - width/2, values1, width, label='Group 1')
bars2 = ax.bar(x + width/2, values2, width, label='Group 2')
ax.set_xticks(x)
ax.set_xticklabels(categories)
ax.legend()
plt.show()
# Stacked bar chart
fig, ax = plt.subplots()
ax.bar(categories, values1, label='Part 1')
ax.bar(categories, values2, bottom=values1, label='Part 2')
ax.legend()
plt.show()
Histograms
# Basic histogram
data = np.random.randn(1000)
fig, ax = plt.subplots()
ax.hist(data, bins=30, edgecolor='black', alpha=0.7)
ax.set_xlabel('Value')
ax.set_ylabel('Frequency')
ax.set_title('Histogram')
plt.show()
# Multiple histograms
data1 = np.random.normal(0, 1, 1000)
data2 = np.random.normal(2, 1.5, 1000)
fig, ax = plt.subplots()
ax.hist(data1, bins=30, alpha=0.5, label='Distribution 1')
ax.hist(data2, bins=30, alpha=0.5, label='Distribution 2')
ax.legend()
plt.show()
# Normalized histogram (density)
fig, ax = plt.subplots()
ax.hist(data, bins=30, density=True, alpha=0.7,
edgecolor='black', label='Data')
# Overlay theoretical distribution
mu, sigma = 0, 1
x = np.linspace(data.min(), data.max(), 100)
ax.plot(x, 1/(sigma * np.sqrt(2 * np.pi)) *
np.exp(-0.5 * ((x - mu)/sigma)**2),
'r-', linewidth=2, label='Theoretical')
ax.legend()
plt.show()
# 2D histogram (hexbin)
x = np.random.randn(10000)
y = np.random.randn(10000)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
ax1.hist2d(x, y, bins=50, cmap='Blues')
ax1.set_title('2D Histogram')
hexbin = ax2.hexbin(x, y, gridsize=30, cmap='Reds')
ax2.set_title('Hexbin')
plt.colorbar(hexbin, ax=ax2)
plt.show()
Pie Charts
# Basic pie chart
sizes = [25, 35, 20, 20]
labels = ['A', 'B', 'C', 'D']
fig, ax = plt.subplots()
ax.pie(sizes, labels=labels, autopct='%1.1f%%',
startangle=90)
ax.axis('equal') # Equal aspect ratio
plt.show()
# Exploded pie chart with custom colors
explode = (0.1, 0, 0, 0)
colors = ['#ff9999', '#66b3ff', '#99ff99', '#ffcc99']
fig, ax = plt.subplots()
wedges, texts, autotexts = ax.pie(sizes, labels=labels,
autopct='%1.1f%%',
startangle=90,
explode=explode,
colors=colors,
shadow=True)
# Customize text
for autotext in autotexts:
autotext.set_color('white')
autotext.set_weight('bold')
plt.show()
# Donut chart
fig, ax = plt.subplots()
ax.pie(sizes, labels=labels, autopct='%1.1f%%',
wedgeprops=dict(width=0.5)) # Creates donut
ax.axis('equal')
plt.show()
Figure and Axes Management
Creating Figures and Subplots
# Method 1: plt.subplots() (Recommended)
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot([1, 2, 3], [1, 4, 2])
# Method 2: Multiple subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
ax1.plot([1, 2, 3])
ax2.plot([3, 2, 1])
# Method 3: Grid of subplots
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
for i, ax in enumerate(axes.flat):
ax.plot(np.random.randn(10))
ax.set_title(f'Subplot {i+1}')
# Method 4: Figure first, then add axes
fig = plt.figure(figsize=(10, 6))
ax = fig.add_subplot(111) # 1 row, 1 col, index 1
Complex Layouts with GridSpec
import matplotlib.gridspec as gridspec
# GridSpec for flexible layouts
fig = plt.figure(figsize=(12, 8))
gs = gridspec.GridSpec(3, 3, figure=fig)
# Span multiple cells
ax1 = fig.add_subplot(gs[0, :]) # First row, all columns
ax2 = fig.add_subplot(gs[1, :-1]) # Second row, first 2 columns
ax3 = fig.add_subplot(gs[1:, -1]) # Last 2 rows, last column
ax4 = fig.add_subplot(gs[-1, 0]) # Last row, first column
ax5 = fig.add_subplot(gs[-1, 1]) # Last row, second column
ax1.plot(np.random.randn(100))
ax1.set_title('Wide Top Panel')
ax2.plot(np.random.randn(100))
ax2.set_title('Middle Left')
ax3.plot(np.random.randn(100))
ax3.set_title('Right Panel')
ax4.plot(np.random.randn(100))
ax5.plot(np.random.randn(100))
plt.tight_layout()
plt.show()
# Unequal spacing
fig = plt.figure(figsize=(12, 8))
gs = gridspec.GridSpec(2, 2,
width_ratios=[2, 1],
height_ratios=[1, 2],
hspace=0.3, wspace=0.3)
for i in range(4):
ax = fig.add_subplot(gs[i])
ax.plot(np.random.randn(100))
ax.set_title(f'Subplot {i+1}')
plt.show()
Subplot Sharing and Linking
# Shared axes
fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(10, 8))
x = np.linspace(0, 10, 100)
ax1.plot(x, np.sin(x))
ax1.set_ylabel('sin(x)')
ax2.plot(x, np.cos(x))
ax2.set_ylabel('cos(x)')
ax2.set_xlabel('x')
plt.show()
# Grid with shared axes
fig, axes = plt.subplots(2, 2, sharex='col', sharey='row',
figsize=(10, 8))
for i in range(2):
for j in range(2):
axes[i, j].plot(np.random.randn(100).cumsum())
plt.show()
Inset Axes and Zooming
from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset
fig, ax = plt.subplots(figsize=(10, 6))
# Main plot
x = np.linspace(0, 10, 1000)
y = np.sin(x) * np.exp(-x/10)
ax.plot(x, y)
# Inset axes
axins = inset_axes(ax, width="40%", height="30%", loc='upper right')
axins.plot(x, y)
axins.set_xlim(2, 3)
axins.set_ylim(0.3, 0.5)
axins.set_xticks([])
axins.set_yticks([])
# Mark the inset region
mark_inset(ax, axins, loc1=2, loc2=4, fc="none", ec="0.5")
plt.show()
Twin Axes (Two Y-axes)
fig, ax1 = plt.subplots(figsize=(10, 6))
x = np.linspace(0, 10, 100)
y1 = np.sin(x)
y2 = np.exp(x/5)
# First y-axis
color = 'tab:blue'
ax1.set_xlabel('X')
ax1.set_ylabel('sin(x)', color=color)
ax1.plot(x, y1, color=color)
ax1.tick_params(axis='y', labelcolor=color)
# Second y-axis
ax2 = ax1.twinx()
color = 'tab:red'
ax2.set_ylabel('exp(x/5)', color=color)
ax2.plot(x, y2, color=color)
ax2.tick_params(axis='y', labelcolor=color)
fig.tight_layout()
plt.show()
Customization Deep Dive
Colors
# Named colors
colors = ['red', 'green', 'blue', 'cyan', 'magenta', 'yellow', 'black']
# Hex colors
colors = ['#FF5733', '#33FF57', '#3357FF']
# RGB tuples (0-1)
colors = [(0.8, 0.2, 0.1), (0.1, 0.8, 0.2)]
# RGBA with transparency
colors = [(0.8, 0.2, 0.1, 0.5)]
# Colormaps
x = np.linspace(0, 10, 100)
fig, ax = plt.subplots()
for i in range(10):
color = plt.cm.viridis(i / 10) # Get color from colormap
ax.plot(x, np.sin(x + i/5), color=color)
plt.show()
# Popular colormaps
cmaps = ['viridis', 'plasma', 'inferno', 'magma', 'cividis', # Perceptually uniform
'coolwarm', 'RdYlBu', 'RdYlGn', # Diverging
'Greys', 'Blues', 'Reds', # Sequential
'tab10', 'tab20', 'Set1'] # Qualitative
# Custom colormap
from matplotlib.colors import LinearSegmentedColormap
colors_list = ['blue', 'white', 'red']
n_bins = 100
cmap = LinearSegmentedColormap.from_list('custom', colors_list, N=n_bins)
# Using colormap
data = np.random.rand(10, 10)
fig, ax = plt.subplots()
im = ax.imshow(data, cmap=cmap)
plt.colorbar(im, ax=ax)
plt.show()
Markers and Line Styles
# Markers
markers = ['.', 'o', 'v', '^', '<', '>', 's', 'p', '*', 'h', 'H',
'+', 'x', 'D', 'd', '|', '_']
fig, ax = plt.subplots(figsize=(12, 6))
x = np.arange(len(markers))
for i, marker in enumerate(markers):
ax.plot(i, i, marker=marker, markersize=10, label=marker)
ax.legend(ncol=6)
plt.show()
# Line styles
linestyles = ['-', '--', '-.', ':']
fig, ax = plt.subplots()
x = np.linspace(0, 10, 100)
for i, ls in enumerate(linestyles):
ax.plot(x, np.sin(x) + i, linestyle=ls, linewidth=2,
label=f"'{ls}'")
ax.legend()
plt.show()
# Combined format string
# Format: '[marker][line][color]'
fig, ax = plt.subplots()
ax.plot([1, 2, 3], [1, 4, 2], 'ro-') # Red circles with solid line
ax.plot([1, 2, 3], [2, 3, 1], 'bs--') # Blue squares with dashed line
ax.plot([1, 2, 3], [0.5, 2.5, 1.5], 'g^:') # Green triangles with dotted line
plt.show()
# Detailed customization
fig, ax = plt.subplots()
ax.plot([1, 2, 3], [1, 4, 2],
marker='o',
markersize=10,
markerfacecolor='red',
markeredgecolor='black',
markeredgewidth=2,
linestyle='--',
linewidth=2,
color='blue',
alpha=0.7)
plt.show()
Labels, Titles, and Legends
fig, ax = plt.subplots(figsize=(10, 6))
x = np.linspace(0, 10, 100)
ax.plot(x, np.sin(x), label='sin(x)')
ax.plot(x, np.cos(x), label='cos(x)')
# Title with customization
ax.set_title('Trigonometric Functions',
fontsize=16, fontweight='bold',
pad=20)
# Axis labels
ax.set_xlabel('X Axis', fontsize=12, fontweight='bold')
ax.set_ylabel('Y Axis', fontsize=12, fontweight='bold')
# Legend customization
ax.legend(loc='upper right', # Location
frameon=True, # Frame
fancybox=True, # Rounded corners
shadow=True, # Shadow
ncol=2, # Number of columns
fontsize=10,
title='Functions',
title_fontsize=12)
# Alternative legend locations
# 'best', 'upper right', 'upper left', 'lower left', 'lower right',
# 'right', 'center left', 'center right', 'lower center', 'upper center', 'center'
# Legend outside plot
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()
# Multiple legends
fig, ax = plt.subplots()
line1, = ax.plot([1, 2, 3], [1, 2, 3], 'r-', label='Red')
line2, = ax.plot([1, 2, 3], [3, 2, 1], 'b-', label='Blue')
# First legend
legend1 = ax.legend(handles=[line1], loc='upper left')
ax.add_artist(legend1) # Add first legend back
# Second legend
ax.legend(handles=[line2], loc='upper right')
plt.show()
Tick Customization
fig, ax = plt.subplots(figsize=(10, 6))
x = np.linspace(0, 10, 100)
ax.plot(x, np.sin(x))
# Tick positions
ax.set_xticks([0, 2, 4, 6, 8, 10])
ax.set_yticks([-1, -0.5, 0, 0.5, 1])
# Tick labels
ax.set_xticklabels(['Zero', 'Two', 'Four', 'Six', 'Eight', 'Ten'])
# Tick parameters
ax.tick_params(axis='x',
labelsize=10,
labelrotation=45,
labelcolor='blue',
length=6,
width=2,
direction='in')
# Minor ticks
ax.minorticks_on()
ax.tick_params(axis='both', which='minor', length=3)
# Custom tick formatter
from matplotlib.ticker import FuncFormatter
def currency(x, pos):
return f'${x:.2f}'
ax.yaxis.set_major_formatter(FuncFormatter(currency))
plt.show()
# Log scale
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
x = np.logspace(0, 3, 100)
y = x ** 2
ax1.plot(x, y)
ax1.set_title('Linear Scale')
ax2.plot(x, y)
ax2.set_xscale('log')
ax2.set_yscale('log')
ax2.set_title('Log Scale')
ax2.grid(True, which='both', alpha=0.3)
plt.show()
Spines and Frames
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
x = np.linspace(-5, 5, 100)
y = x ** 2
# Default
axes[0, 0].plot(x, y)
axes[0, 0].set_title('Default')
# Remove top and right spines
axes[0, 1].plot(x, y)
axes[0, 1].spines['top'].set_visible(False)
axes[0, 1].spines['right'].set_visible(False)
axes[0, 1].set_title('Clean')
# Move spines to zero
axes[1, 0].plot(x, y)
axes[1, 0].spines['left'].set_position('zero')
axes[1, 0].spines['bottom'].set_position('zero')
axes[1, 0].spines['top'].set_visible(False)
axes[1, 0].spines['right'].set_visible(False)
axes[1, 0].set_title('Centered')
# No spines (floating)
axes[1, 1].plot(x, y)
for spine in axes[1, 1].spines.values():
spine.set_visible(False)
axes[1, 1].set_title('No Spines')
plt.tight_layout()
plt.show()
Annotations and Text
fig, ax = plt.subplots(figsize=(10, 6))
x = np.linspace(0, 10, 100)
y = np.sin(x)
ax.plot(x, y)
# Simple text
ax.text(5, 0.5, 'Peak Region', fontsize=12)
# Text with box
bbox_props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
ax.text(8, -0.5, 'Trough', fontsize=12, bbox=bbox_props)
# Annotation with arrow
ax.annotate('Maximum',
xy=(np.pi/2, 1), # Point to annotate
xytext=(2, 0.5), # Text position
fontsize=12,
arrowprops=dict(facecolor='red',
shrink=0.05,
width=2,
headwidth=8))
# Multiple annotation styles
ax.annotate('Fancy Arrow',
xy=(3*np.pi/2, -1),
xytext=(7, -0.3),
arrowprops=dict(arrowstyle='->',
connectionstyle='arc3,rad=0.3',
color='blue',
lw=2))
# Mathematical text (LaTeX)
ax.text(1, -0.8, r'$y = \sin(x)$', fontsize=16)
ax.text(5, -0.8, r'$\int_0^{\pi} \sin(x)dx = 2$', fontsize=14)
plt.show()
# Arrow styles
arrow_styles = ['-', '->', '-[', '|-|', '-|>', '<-', '<->',
'fancy', 'simple', 'wedge']
Adding Shapes
from matplotlib.patches import Circle, Rectangle, Polygon, Ellipse, FancyBboxPatch
from matplotlib.collections import PatchCollection
fig, ax = plt.subplots(figsize=(10, 8))
# Circle
circle = Circle((2, 2), 0.5, color='red', alpha=0.5)
ax.add_patch(circle)
# Rectangle
rect = Rectangle((4, 1), 1, 2, color='blue', alpha=0.5)
ax.add_patch(rect)
# Ellipse
ellipse = Ellipse((7, 2), 1, 2, angle=30, color='green', alpha=0.5)
ax.add_patch(ellipse)
# Polygon
triangle = Polygon([[1, 4], [2, 6], [3, 4]], color='purple', alpha=0.5)
ax.add_patch(triangle)
# Fancy box
fancy = FancyBboxPatch((5, 4), 2, 1.5,
boxstyle="round,pad=0.1",
edgecolor='orange',
facecolor='yellow',
linewidth=2,
alpha=0.5)
ax.add_patch(fancy)
ax.set_xlim(0, 10)
ax.set_ylim(0, 8)
ax.set_aspect('equal')
plt.show()
Advanced Plot Types
3D Plots
from mpl_toolkits.mplot3d import Axes3D
# 3D line plot
fig = plt.figure(figsize=(12, 5))
ax = fig.add_subplot(121, projection='3d')
t = np.linspace(0, 10, 1000)
x = np.sin(t)
y = np.cos(t)
z = t
ax.plot(x, y, z, linewidth=2)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.set_title('3D Line Plot')
# 3D scatter
ax = fig.add_subplot(122, projection='3d')
x = np.random.randn(100)
y = np.random.randn(100)
z = np.random.randn(100)
colors = np.random.rand(100)
scatter = ax.scatter(x, y, z, c=colors, cmap='viridis', s=50)
ax.set_title('3D Scatter')
plt.colorbar(scatter, ax=ax)
plt.show()
# 3D surface
fig = plt.figure(figsize=(12, 5))
# Surface plot
ax = fig.add_subplot(121, projection='3d')
x = np.linspace(-5, 5, 50)
y = np.linspace(-5, 5, 50)
X, Y = np.meshgrid(x, y)
Z = np.sin(np.sqrt(X**2 + Y**2))
surf = ax.plot_surface(X, Y, Z, cmap='coolwarm', alpha=0.8)
ax.set_title('Surface Plot')
plt.colorbar(surf, ax=ax, shrink=0.5)
# Wireframe
ax = fig.add_subplot(122, projection='3d')
ax.plot_wireframe(X, Y, Z, color='blue', linewidth=0.5)
ax.set_title('Wireframe Plot')
plt.show()
# Contour3D
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
ax.contour3D(X, Y, Z, 50, cmap='viridis')
ax.set_title('3D Contour')
plt.show()
Contour Plots
# 2D contour
x = np.linspace(-3, 3, 100)
y = np.linspace(-3, 3, 100)
X, Y = np.meshgrid(x, y)
Z = np.sin(X) * np.cos(Y)
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 4))
# Filled contour
contourf = ax1.contourf(X, Y, Z, levels=20, cmap='RdYlBu')
ax1.set_title('Filled Contour')
plt.colorbar(contourf, ax=ax1)
# Line contour
contour = ax2.contour(X, Y, Z, levels=10, colors='black')
ax2.clabel(contour, inline=True, fontsize=8) # Label contours
ax2.set_title('Line Contour')
# Combined
ax3.contourf(X, Y, Z, levels=20, cmap='RdYlBu', alpha=0.7)
contour = ax3.contour(X, Y, Z, levels=10, colors='black', linewidths=0.5)
ax3.clabel(contour, inline=True, fontsize=8)
ax3.set_title('Combined')
plt.tight_layout()
plt.show()
Heatmaps and imshow
# Heatmap
data = np.random.rand(10, 12)
fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(data, cmap='YlOrRd', aspect='auto')
# Colorbar
cbar = plt.colorbar(im, ax=ax)
cbar.set_label('Value', rotation=270, labelpad=20)
# Ticks and labels
ax.set_xticks(np.arange(12))
ax.set_yticks(np.arange(10))
ax.set_xticklabels([f'Col {i}' for i in range(12)])
ax.set_yticklabels([f'Row {i}' for i in range(10)])
# Rotate x labels
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
# Add values in cells
for i in range(10):
for j in range(12):
text = ax.text(j, i, f'{data[i, j]:.2f}',
ha='center', va='center', color='black')
ax.set_title('Heatmap with Values')
plt.tight_layout()
plt.show()
Error Bars
x = np.linspace(0, 10, 20)
y = np.sin(x)
yerr = 0.1 + 0.05 * np.random.rand(len(x))
xerr = 0.1 + 0.05 * np.random.rand(len(x))
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 4))
# Y error bars only
ax1.errorbar(x, y, yerr=yerr, fmt='o-', capsize=5,
capthick=2, label='Data')
ax1.set_title('Y Error Bars')
ax1.legend()
# X and Y error bars
ax2.errorbar(x, y, xerr=xerr, yerr=yerr, fmt='s-',
capsize=5, alpha=0.7)
ax2.set_title('X and Y Error Bars')
# Shaded error region
ax3.plot(x, y, 'o-', label='Mean')
ax3.fill_between(x, y - yerr, y + yerr, alpha=0.3, label='±1 std')
ax3.set_title('Shaded Error Region')
ax3.legend()
plt.tight_layout()
plt.show()
Box Plots and Violin Plots
# Generate sample data
data = [np.random.normal(0, std, 100) for std in range(1, 5)]
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
# Box plot
bp = ax1.boxplot(data,
labels=['Group 1', 'Group 2', 'Group 3', 'Group 4'],
notch=True, # Notched box
patch_artist=True) # Fill with color
# Customize colors
colors = ['lightblue', 'lightgreen', 'pink', 'lightyellow']
for patch, color in zip(bp['boxes'], colors):
patch.set_facecolor(color)
ax1.set_title('Box Plot')
ax1.set_ylabel('Values')
# Violin plot
parts = ax2.violinplot(data, showmeans=True, showmedians=True)
ax2.set_title('Violin Plot')
ax2.set_xticks([1, 2, 3, 4])
ax2.set_xticklabels(['Group 1', 'Group 2', 'Group 3', 'Group 4'])
plt.tight_layout()
plt.show()
# Horizontal box plot
fig, ax = plt.subplots(figsize=(8, 6))
ax.boxplot(data, vert=False, labels=['A', 'B', 'C', 'D'])
ax.set_xlabel('Values')
plt.show()
Stream Plots and Quiver Plots
# Vector field (quiver plot)
x = np.linspace(-3, 3, 20)
y = np.linspace(-3, 3, 20)
X, Y = np.meshgrid(x, y)
U = -Y # x-component
V = X # y-component
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
# Quiver plot
ax1.quiver(X, Y, U, V, alpha=0.8)
ax1.set_title('Quiver Plot (Vector Field)')
ax1.set_aspect('equal')
# Stream plot
ax2.streamplot(X, Y, U, V, density=1.5, color=np.sqrt(U**2 + V**2),
cmap='viridis', linewidth=1)
ax2.set_title('Stream Plot')
ax2.set_aspect('equal')
plt.tight_layout()
plt.show()
Polar Plots
# Polar line plot
theta = np.linspace(0, 2*np.pi, 100)
r = 1 + np.sin(4*theta)
fig, (ax1, ax2) = plt.subplots(1, 2, subplot_kw=dict(projection='polar'),
figsize=(12, 5))
ax1.plot(theta, r)
ax1.set_title('Polar Line Plot')
# Polar scatter with colors
theta2 = np.random.uniform(0, 2*np.pi, 100)
r2 = np.random.uniform(0, 2, 100)
colors = theta2
ax2.scatter(theta2, r2, c=colors, cmap='hsv', alpha=0.75)
ax2.set_title('Polar Scatter')
plt.show()
# Polar bar (rose diagram)
fig, ax = plt.subplots(subplot_kw=dict(projection='polar'))
theta = np.linspace(0, 2*np.pi, 8, endpoint=False)
radii = np.random.rand(8) * 10
width = 2*np.pi / 8
bars = ax.bar(theta, radii, width=width, bottom=0.0, alpha=0.7)
# Color bars by height
for r, bar in zip(radii, bars):
bar.set_facecolor(plt.cm.viridis(r / 10))
plt.show()
Styling and Themes
Built-in Styles
# See available styles
print(plt.style.available)
# Use a style
plt.style.use('seaborn-v0_8-darkgrid')
# Or: 'ggplot', 'fivethirtyeight', 'bmh', 'dark_background', etc.
# Example with different styles
styles = ['default', 'seaborn-v0_8-darkgrid', 'ggplot', 'fivethirtyeight']
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
x = np.linspace(0, 10, 100)
for ax, style in zip(axes.flat, styles):
with plt.style.context(style):
ax.plot(x, np.sin(x), label='sin(x)')
ax.plot(x, np.cos(x), label='cos(x)')
ax.set_title(style)
ax.legend()
plt.tight_layout()
plt.show()
rcParams Configuration
import matplotlib as mpl
# View current settings
print(mpl.rcParams['font.size'])
# Temporary changes
with mpl.rc_context({'font.size': 14, 'lines.linewidth': 2}):
plt.plot([1, 2, 3], [1, 4, 2])
plt.show()
# Global changes (persists for session)
mpl.rcParams['font.size'] = 12
mpl.rcParams['font.family'] = 'serif'
mpl.rcParams['figure.figsize'] = (10, 6)
mpl.rcParams['axes.grid'] = True
mpl.rcParams['grid.alpha'] = 0.3
mpl.rcParams['lines.linewidth'] = 2
mpl.rcParams['axes.spines.top'] = False
mpl.rcParams['axes.spines.right'] = False
# Reset to defaults
mpl.rcParams.update(mpl.rcParamsDefault)
# Common rcParams for publications
pub_params = {
'font.size': 10,
'font.family': 'serif',
'font.serif': ['Times New Roman'],
'axes.labelsize': 12,
'axes.titlesize': 14,
'xtick.labelsize': 10,
'ytick.labelsize': 10,
'legend.fontsize': 10,
'figure.figsize': (6, 4),
'figure.dpi': 300,
'savefig.dpi': 300,
'savefig.bbox': 'tight',
'axes.linewidth': 1,
'lines.linewidth': 1.5,
}
mpl.rcParams.update(pub_params)
Custom Style Sheets
# Create custom style file: ~/.matplotlib/stylelib/mystyle.mplstyle
"""
# mystyle.mplstyle
figure.figsize: 10, 6
figure.dpi: 100
axes.grid: True
axes.grid.axis: both
grid.alpha: 0.3
grid.linestyle: --
axes.spines.top: False
axes.spines.right: False
font.size: 12
axes.labelsize: 14
axes.titlesize: 16
lines.linewidth: 2
lines.markersize: 8
legend.frameon: False
legend.loc: best
"""
# Use custom style
# plt.style.use('mystyle')
# Or use directly with context
custom_style = {
'axes.grid': True,
'grid.alpha': 0.3,
'axes.spines.top': False,
'axes.spines.right': False,
}
with plt.style.context(custom_style):
plt.plot([1, 2, 3], [1, 4, 2])
plt.show()
ML/Data Science Visualizations
Confusion Matrix
from sklearn.metrics import confusion_matrix
import itertools
def plot_confusion_matrix(cm, classes, normalize=False,
cmap=plt.cm.Blues):
"""
Plot confusion matrix
"""
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
fig, ax = plt.subplots(figsize=(8, 8))
im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
ax.figure.colorbar(im, ax=ax)
ax.set(xticks=np.arange(cm.shape[1]),
yticks=np.arange(cm.shape[0]),
xticklabels=classes,
yticklabels=classes,
ylabel='True label',
xlabel='Predicted label')
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
# Add text annotations
fmt = '.2f' if normalize else 'd'
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
ax.text(j, i, format(cm[i, j], fmt),
ha='center', va='center',
color='white' if cm[i, j] > thresh else 'black')
fig.tight_layout()
return ax
# Example usage
y_true = np.random.randint(0, 3, 100)
y_pred = np.random.randint(0, 3, 100)
cm = confusion_matrix(y_true, y_pred)
plot_confusion_matrix(cm, classes=['Class A', 'Class B', 'Class C'])
plt.show()
ROC Curve and AUC
from sklearn.metrics import roc_curve, auc
def plot_roc_curve(y_true, y_scores, n_classes):
"""
Plot ROC curves for multi-class classification
"""
fig, ax = plt.subplots(figsize=(10, 8))
colors = plt.cm.Set1(np.linspace(0, 1, n_classes))
for i, color in enumerate(colors):
# Binary indicators for class i
y_true_binary = (y_true == i).astype(int)
y_score_class = y_scores[:, i]
fpr, tpr, _ = roc_curve(y_true_binary, y_score_class)
roc_auc = auc(fpr, tpr)
ax.plot(fpr, tpr, color=color, lw=2,
label=f'Class {i} (AUC = {roc_auc:.2f})')
# Diagonal line
ax.plot([0, 1], [0, 1], 'k--', lw=2, label='Random Classifier')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate', fontsize=12)
ax.set_ylabel('True Positive Rate', fontsize=12)
ax.set_title('ROC Curves', fontsize=14)
ax.legend(loc='lower right')
ax.grid(alpha=0.3)
return ax
# Example
n_samples, n_classes = 1000, 3
y_true = np.random.randint(0, n_classes, n_samples)
y_scores = np.random.rand(n_samples, n_classes)
y_scores = y_scores / y_scores.sum(axis=1, keepdims=True) # Normalize
plot_roc_curve(y_true, y_scores, n_classes)
plt.show()
Learning Curves
def plot_learning_curves(train_losses, val_losses, train_accs=None, val_accs=None):
"""
Plot training and validation loss/accuracy curves
"""
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
epochs = range(1, len(train_losses) + 1)
# Loss curves
axes[0].plot(epochs, train_losses, 'b-o', label='Training Loss',
markersize=4)
axes[0].plot(epochs, val_losses, 'r-s', label='Validation Loss',
markersize=4)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training and Validation Loss')
axes[0].legend()
axes[0].grid(alpha=0.3)
# Accuracy curves (if provided)
if train_accs is not None and val_accs is not None:
axes[1].plot(epochs, train_accs, 'b-o', label='Training Accuracy',
markersize=4)
axes[1].plot(epochs, val_accs, 'r-s', label='Validation Accuracy',
markersize=4)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Training and Validation Accuracy')
axes[1].legend()
axes[1].grid(alpha=0.3)
else:
axes[1].axis('off')
plt.tight_layout()
return fig
# Example
epochs = 50
train_losses = 2.0 * np.exp(-np.arange(epochs) / 10) + 0.1 * np.random.rand(epochs)
val_losses = 2.0 * np.exp(-np.arange(epochs) / 10) + 0.2 * np.random.rand(epochs) + 0.1
train_accs = 1 - np.exp(-np.arange(epochs) / 10) * 0.9
val_accs = 1 - np.exp(-np.arange(epochs) / 10) * 0.9 - 0.05
plot_learning_curves(train_losses, val_losses, train_accs, val_accs)
plt.show()
Feature Importance
def plot_feature_importance(feature_names, importances, top_n=20):
"""
Plot feature importance bar chart
"""
# Sort by importance
indices = np.argsort(importances)[::-1][:top_n]
sorted_importances = importances[indices]
sorted_names = [feature_names[i] for i in indices]
fig, ax = plt.subplots(figsize=(10, 8))
# Horizontal bar chart
y_pos = np.arange(len(sorted_names))
colors = plt.cm.viridis(sorted_importances / sorted_importances.max())
bars = ax.barh(y_pos, sorted_importances, color=colors)
ax.set_yticks(y_pos)
ax.set_yticklabels(sorted_names)
ax.invert_yaxis() # Top feature at the top
ax.set_xlabel('Importance', fontsize=12)
ax.set_title(f'Top {top_n} Feature Importances', fontsize=14)
# Add value labels
for i, (bar, val) in enumerate(zip(bars, sorted_importances)):
ax.text(val, i, f' {val:.3f}', va='center')
plt.tight_layout()
return fig
# Example
n_features = 50
feature_names = [f'Feature_{i}' for i in range(n_features)]
importances = np.random.exponential(0.1, n_features)
plot_feature_importance(feature_names, importances, top_n=15)
plt.show()
Decision Boundaries
def plot_decision_boundary(X, y, model, resolution=0.02):
"""
Plot decision boundary for 2D classification
"""
# Setup marker generator and color map
markers = ('o', 's', '^', 'v', '<')
colors = ('red', 'blue', 'green', 'gray', 'cyan')
cmap = plt.cm.RdYlBu
# Plot decision surface
x1_min, x1_max = X[:, 0].min() - 1, X[:, 0].max() + 1
x2_min, x2_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx1, xx2 = np.meshgrid(np.arange(x1_min, x1_max, resolution),
np.arange(x2_min, x2_max, resolution))
# Predict on grid
Z = model.predict(np.array([xx1.ravel(), xx2.ravel()]).T)
Z = Z.reshape(xx1.shape)
fig, ax = plt.subplots(figsize=(10, 8))
# Plot filled contour
ax.contourf(xx1, xx2, Z, alpha=0.3, cmap=cmap)
ax.contour(xx1, xx2, Z, colors='black', linewidths=0.5, alpha=0.5)
# Plot data points
for idx, cl in enumerate(np.unique(y)):
ax.scatter(x=X[y == cl, 0], y=X[y == cl, 1],
alpha=0.8, c=[colors[idx]], marker=markers[idx],
s=100, edgecolor='black', label=f'Class {cl}')
ax.set_xlabel('Feature 1', fontsize=12)
ax.set_ylabel('Feature 2', fontsize=12)
ax.set_title('Decision Boundary', fontsize=14)
ax.legend()
return fig
# Example (requires a model with predict method)
# from sklearn.svm import SVC
# X = np.random.randn(200, 2)
# y = (X[:, 0] + X[:, 1] > 0).astype(int)
# model = SVC(kernel='rbf').fit(X, y)
# plot_decision_boundary(X, y, model)
Attention Heatmap
def plot_attention_heatmap(attention_matrix, x_labels=None, y_labels=None):
"""
Plot attention weights as heatmap
Useful for visualizing transformer attention
"""
fig, ax = plt.subplots(figsize=(12, 10))
im = ax.imshow(attention_matrix, cmap='YlOrRd', aspect='auto')
# Colorbar
cbar = plt.colorbar(im, ax=ax)
cbar.set_label('Attention Weight', rotation=270, labelpad=20)
# Labels
if x_labels is not None:
ax.set_xticks(np.arange(len(x_labels)))
ax.set_xticklabels(x_labels, rotation=45, ha='right')
if y_labels is not None:
ax.set_yticks(np.arange(len(y_labels)))
ax.set_yticklabels(y_labels)
ax.set_xlabel('Keys', fontsize=12)
ax.set_ylabel('Queries', fontsize=12)
ax.set_title('Attention Heatmap', fontsize=14)
# Grid
ax.set_xticks(np.arange(attention_matrix.shape[1]) - 0.5, minor=True)
ax.set_yticks(np.arange(attention_matrix.shape[0]) - 0.5, minor=True)
ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5)
plt.tight_layout()
return fig
# Example
seq_len = 10
attention = np.random.rand(seq_len, seq_len)
attention = attention / attention.sum(axis=1, keepdims=True) # Normalize
tokens = [f'Token_{i}' for i in range(seq_len)]
plot_attention_heatmap(attention, x_labels=tokens, y_labels=tokens)
plt.show()
Image Grid
def plot_image_grid(images, labels=None, nrows=4, ncols=4, figsize=(12, 12)):
"""
Display a grid of images
"""
fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
for idx, ax in enumerate(axes.flat):
if idx < len(images):
# Handle grayscale and RGB
if images[idx].ndim == 2:
ax.imshow(images[idx], cmap='gray')
else:
ax.imshow(images[idx])
if labels is not None:
ax.set_title(f'Label: {labels[idx]}')
ax.axis('off')
else:
ax.axis('off')
plt.tight_layout()
return fig
# Example
n_images = 16
images = [np.random.rand(28, 28) for _ in range(n_images)]
labels = np.random.randint(0, 10, n_images)
plot_image_grid(images, labels, nrows=4, ncols=4)
plt.show()
Correlation Matrix
def plot_correlation_matrix(data, feature_names=None, method='pearson'):
"""
Plot correlation matrix heatmap
"""
# Compute correlation
if method == 'pearson':
corr = np.corrcoef(data.T)
elif method == 'spearman':
from scipy.stats import spearmanr
corr, _ = spearmanr(data)
fig, ax = plt.subplots(figsize=(12, 10))
# Plot heatmap
im = ax.imshow(corr, cmap='coolwarm', vmin=-1, vmax=1, aspect='auto')
# Colorbar
cbar = plt.colorbar(im, ax=ax)
cbar.set_label('Correlation', rotation=270, labelpad=20)
# Labels
if feature_names is not None:
ax.set_xticks(np.arange(len(feature_names)))
ax.set_yticks(np.arange(len(feature_names)))
ax.set_xticklabels(feature_names, rotation=45, ha='right')
ax.set_yticklabels(feature_names)
# Add correlation values
for i in range(corr.shape[0]):
for j in range(corr.shape[1]):
text = ax.text(j, i, f'{corr[i, j]:.2f}',
ha='center', va='center',
color='white' if abs(corr[i, j]) > 0.5 else 'black',
fontsize=8)
ax.set_title(f'{method.capitalize()} Correlation Matrix', fontsize=14)
plt.tight_layout()
return fig
# Example
n_samples, n_features = 100, 10
data = np.random.randn(n_samples, n_features)
feature_names = [f'Feature {i}' for i in range(n_features)]
plot_correlation_matrix(data, feature_names)
plt.show()
Working with Images
Displaying Images
# Single image
img = np.random.rand(100, 100, 3) # RGB
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(img)
ax.axis('off')
plt.show()
# Grayscale
img_gray = np.random.rand(100, 100)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
ax1.imshow(img_gray, cmap='gray')
ax1.set_title('Grayscale (gray cmap)')
ax1.axis('off')
ax2.imshow(img_gray, cmap='viridis')
ax2.set_title('Grayscale (viridis cmap)')
ax2.axis('off')
plt.show()
# Control interpolation
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
small_img = np.random.rand(10, 10)
interpolations = ['nearest', 'bilinear', 'bicubic', 'lanczos']
for ax, interp in zip(axes.flat, interpolations):
ax.imshow(small_img, cmap='gray', interpolation=interp)
ax.set_title(f'Interpolation: {interp}')
ax.axis('off')
plt.tight_layout()
plt.show()
Image Operations
# Load image (with PIL or similar)
# from PIL import Image
# img = np.array(Image.open('image.jpg'))
# Simulated image
img = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
# Original
axes[0, 0].imshow(img)
axes[0, 0].set_title('Original')
axes[0, 0].axis('off')
# Channels
axes[0, 1].imshow(img[:, :, 0], cmap='Reds')
axes[0, 1].set_title('Red Channel')
axes[0, 1].axis('off')
axes[0, 2].imshow(img[:, :, 1], cmap='Greens')
axes[0, 2].set_title('Green Channel')
axes[0, 2].axis('off')
axes[1, 0].imshow(img[:, :, 2], cmap='Blues')
axes[1, 0].set_title('Blue Channel')
axes[1, 0].axis('off')
# Histogram
axes[1, 1].hist(img[:, :, 0].ravel(), bins=50, alpha=0.5, color='red', label='R')
axes[1, 1].hist(img[:, :, 1].ravel(), bins=50, alpha=0.5, color='green', label='G')
axes[1, 1].hist(img[:, :, 2].ravel(), bins=50, alpha=0.5, color='blue', label='B')
axes[1, 1].set_title('Histogram')
axes[1, 1].legend()
# Grayscale
gray = np.mean(img, axis=2)
axes[1, 2].imshow(gray, cmap='gray')
axes[1, 2].set_title('Grayscale')
axes[1, 2].axis('off')
plt.tight_layout()
plt.show()
Image Overlays and Masks
# Base image
img = np.random.rand(100, 100, 3)
# Create mask
mask = np.zeros((100, 100))
mask[30:70, 30:70] = 1
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
# Original
ax1.imshow(img)
ax1.set_title('Original Image')
ax1.axis('off')
# Mask overlay
ax2.imshow(img)
ax2.imshow(mask, alpha=0.5, cmap='Reds')
ax2.set_title('With Mask Overlay')
ax2.axis('off')
# Masked image
masked_img = img.copy()
masked_img[mask == 0] = 0
ax3.imshow(masked_img)
ax3.set_title('Masked Image')
ax3.axis('off')
plt.tight_layout()
plt.show()
Animations
Basic Animation
from matplotlib.animation import FuncAnimation
# Create figure
fig, ax = plt.subplots(figsize=(8, 6))
xdata, ydata = [], []
ln, = ax.plot([], [], 'r-', animated=True)
def init():
ax.set_xlim(0, 2*np.pi)
ax.set_ylim(-1, 1)
return ln,
def update(frame):
xdata.append(frame)
ydata.append(np.sin(frame))
ln.set_data(xdata, ydata)
return ln,
ani = FuncAnimation(fig, update, frames=np.linspace(0, 2*np.pi, 128),
init_func=init, blit=True, interval=20)
# Save animation
# ani.save('sine_wave.gif', writer='pillow', fps=30)
# ani.save('sine_wave.mp4', writer='ffmpeg', fps=30)
plt.show()
Animated Scatter
# Animated scatter plot
fig, ax = plt.subplots(figsize=(8, 6))
scat = ax.scatter([], [], s=100, alpha=0.6)
ax.set_xlim(-5, 5)
ax.set_ylim(-5, 5)
def init():
scat.set_offsets(np.empty((0, 2)))
return scat,
def update(frame):
# Generate random walk
n_points = 50
x = np.random.randn(n_points).cumsum() * 0.1
y = np.random.randn(n_points).cumsum() * 0.1
data = np.c_[x, y]
scat.set_offsets(data)
scat.set_array(np.arange(n_points))
return scat,
ani = FuncAnimation(fig, update, frames=100, init_func=init,
blit=True, interval=50)
plt.show()
Animated Heatmap
# Animated heatmap (useful for gradient visualization)
fig, ax = plt.subplots(figsize=(8, 6))
def animate(frame):
ax.clear()
data = np.random.rand(10, 10) * frame / 100
im = ax.imshow(data, cmap='hot', vmin=0, vmax=1)
ax.set_title(f'Frame {frame}')
return [im]
ani = FuncAnimation(fig, animate, frames=100, interval=50)
plt.show()
Integration Patterns
With NumPy
# NumPy arrays are matplotlib's native format
x = np.linspace(0, 10, 1000)
y = np.sin(x)
fig, ax = plt.subplots()
ax.plot(x, y)
plt.show()
# Multi-dimensional data
data = np.random.randn(100, 100)
fig, ax = plt.subplots()
im = ax.imshow(data, cmap='viridis')
plt.colorbar(im, ax=ax)
plt.show()
With Pandas
import pandas as pd
# Create sample DataFrame
df = pd.DataFrame({
'x': np.random.randn(100),
'y': np.random.randn(100),
'category': np.random.choice(['A', 'B', 'C'], 100)
})
# Pandas plotting (uses matplotlib)
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
# Histogram
df['x'].hist(ax=axes[0, 0], bins=20)
axes[0, 0].set_title('Histogram')
# Scatter with categories
for cat in df['category'].unique():
subset = df[df['category'] == cat]
axes[0, 1].scatter(subset['x'], subset['y'], label=cat, alpha=0.6)
axes[0, 1].legend()
axes[0, 1].set_title('Scatter by Category')
# Box plot
df.boxplot(column=['x', 'y'], ax=axes[1, 0])
axes[1, 0].set_title('Box Plot')
# Time series
ts_df = pd.DataFrame({
'date': pd.date_range('2023-01-01', periods=100),
'value': np.random.randn(100).cumsum()
})
ts_df.plot(x='date', y='value', ax=axes[1, 1])
axes[1, 1].set_title('Time Series')
plt.tight_layout()
plt.show()
Jupyter Notebook Integration
# Enable inline plotting
%matplotlib inline
# For interactive plots, use one of:
# %matplotlib notebook (older interactive backend)
# %matplotlib widget (newer interactive backend; requires ipympl)
# High-resolution figures
%config InlineBackend.figure_format = 'retina'
# Or in code (IPython.display.set_matplotlib_formats is deprecated)
from matplotlib_inline.backend_inline import set_matplotlib_formats
set_matplotlib_formats('svg') # or 'pdf', 'retina'
Performance & Best Practices
Backends
import matplotlib
# Check current backend
print(matplotlib.get_backend())
# Set backend (do this before importing pyplot)
# matplotlib.use('Agg') # Non-interactive (for servers)
# matplotlib.use('TkAgg') # Interactive
# matplotlib.use('Qt5Agg') # Interactive with Qt
# Common backends:
# - 'Agg': PNG output, no display
# - 'PDF', 'PS', 'SVG': Vector outputs
# - 'TkAgg', 'Qt5Agg', 'GTK3Agg': Interactive
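As a minimal sketch, a headless script (e.g. on a server) can force the non-interactive Agg backend before pyplot is imported and render straight to disk; the output filename here is just a placeholder.

```python
import matplotlib
matplotlib.use('Agg')          # must happen before importing pyplot
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2 * np.pi, 100)
fig, ax = plt.subplots()
ax.plot(x, np.sin(x))
# No window is opened; the figure goes straight to a file
fig.savefig('headless_plot.png', dpi=150)
plt.close(fig)
```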
Memory Management
# Close figures to free memory
fig, ax = plt.subplots()
ax.plot([1, 2, 3])
plt.savefig('plot.png')
plt.close(fig) # Explicitly close
# Or close all figures
plt.close('all')
# For large datasets, downsample
large_x = np.linspace(0, 100, 1000000)
large_y = np.sin(large_x)
# Don't plot all points
step = len(large_x) // 1000
fig, ax = plt.subplots()
ax.plot(large_x[::step], large_y[::step])
plt.show()
Saving Figures
fig, ax = plt.subplots()
ax.plot([1, 2, 3], [1, 4, 2])
# Vector formats (scalable, publication-quality)
plt.savefig('plot.pdf', format='pdf', bbox_inches='tight', dpi=300)
plt.savefig('plot.svg', format='svg', bbox_inches='tight')
plt.savefig('plot.eps', format='eps', bbox_inches='tight')
# Raster formats
plt.savefig('plot.png', format='png', bbox_inches='tight', dpi=300)
plt.savefig('plot.jpg', format='jpg', bbox_inches='tight', dpi=300, pil_kwargs={'quality': 95}) # 'quality' moved into pil_kwargs in Matplotlib 3.3
# Transparent background
plt.savefig('plot.png', transparent=True, bbox_inches='tight', dpi=300)
# Specific size
fig.set_size_inches(8, 6)
plt.savefig('plot.png', dpi=300) # Will be 2400x1800 pixels
Publication-Quality Figures
# Configure for publication
plt.rcParams.update({
'font.size': 10,
'font.family': 'serif',
'axes.labelsize': 12,
'axes.titlesize': 14,
'xtick.labelsize': 10,
'ytick.labelsize': 10,
'legend.fontsize': 10,
'figure.figsize': (6, 4),
'figure.dpi': 300,
'savefig.dpi': 300,
'savefig.bbox': 'tight',
'savefig.pad_inches': 0.1,
'axes.linewidth': 1,
'grid.linewidth': 0.5,
'lines.linewidth': 1.5,
'lines.markersize': 6,
'patch.linewidth': 1,
'xtick.major.width': 1,
'ytick.major.width': 1,
'xtick.minor.width': 0.5,
'ytick.minor.width': 0.5,
})
# Create plot
fig, ax = plt.subplots()
x = np.linspace(0, 10, 100)
ax.plot(x, np.sin(x), label='sin(x)')
ax.plot(x, np.cos(x), label='cos(x)')
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.legend()
ax.grid(alpha=0.3)
# Save for publication
plt.savefig('publication_figure.pdf', format='pdf')
plt.savefig('publication_figure.png', dpi=600) # High DPI for raster
plt.show()
Common Patterns & Recipes
Multi-Panel Figure
# Complex multi-panel figure
fig = plt.figure(figsize=(14, 10))
gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)
# Main plot (spans 2x2)
ax_main = fig.add_subplot(gs[:2, :2])
x = np.linspace(0, 10, 100)
ax_main.plot(x, np.sin(x))
ax_main.set_title('Main Plot', fontsize=14, fontweight='bold')
# Top right
ax_top = fig.add_subplot(gs[0, 2])
ax_top.hist(np.random.randn(1000), bins=30)
ax_top.set_title('Distribution')
# Middle right
ax_mid = fig.add_subplot(gs[1, 2])
ax_mid.scatter(np.random.rand(50), np.random.rand(50))
ax_mid.set_title('Scatter')
# Bottom (spans all columns)
ax_bottom = fig.add_subplot(gs[2, :])
ax_bottom.plot(x, np.cos(x))
ax_bottom.set_title('Bottom Plot')
ax_bottom.set_xlabel('X')
plt.show()
Shared Color Scale
# Multiple subplots with shared colorbar
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
vmin, vmax = -1, 1 # Shared scale
for i, ax in enumerate(axes):
data = np.random.randn(10, 10)
im = ax.imshow(data, cmap='RdBu', vmin=vmin, vmax=vmax)
ax.set_title(f'Subplot {i+1}')
# Single colorbar for all subplots (it reserves space from the axes,
# so skip tight_layout here -- it is not compatible with this layout)
fig.colorbar(im, ax=axes, orientation='horizontal',
fraction=0.05, pad=0.1, label='Value')
plt.show()
Date Plotting
import matplotlib.dates as mdates
from datetime import datetime, timedelta
# Generate time series data
start_date = datetime(2023, 1, 1)
dates = [start_date + timedelta(days=i) for i in range(365)]
values = np.random.randn(365).cumsum()
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(dates, values)
# Format x-axis
ax.xaxis.set_major_locator(mdates.MonthLocator())
ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %Y'))
ax.xaxis.set_minor_locator(mdates.WeekdayLocator())
# Rotate dates
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
ax.set_xlabel('Date')
ax.set_ylabel('Value')
ax.set_title('Time Series with Date Formatting')
ax.grid(alpha=0.3)
plt.tight_layout()
plt.show()
Logarithmic Scales
x = np.logspace(0, 5, 100)
y = x ** 2
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
# Linear-linear
axes[0, 0].plot(x, y)
axes[0, 0].set_title('Linear-Linear')
# Log-linear (semi-log y)
axes[0, 1].semilogy(x, y)
axes[0, 1].set_title('Log-Linear')
# Linear-log (semi-log x)
axes[1, 0].semilogx(x, y)
axes[1, 0].set_title('Linear-Log')
# Log-log
axes[1, 1].loglog(x, y)
axes[1, 1].set_title('Log-Log')
for ax in axes.flat:
ax.grid(True, which='both', alpha=0.3)
plt.tight_layout()
plt.show()
Filled Areas
fig, ax = plt.subplots(figsize=(10, 6))
x = np.linspace(0, 10, 100)
y1 = np.sin(x)
y2 = np.sin(x) + 1
# Fill between two curves
ax.fill_between(x, y1, y2, alpha=0.3, label='Between curves')
# Fill to axis
ax.fill_between(x, 0, y1, where=(y1 > 0), alpha=0.3,
color='green', label='Positive')
ax.fill_between(x, 0, y1, where=(y1 < 0), alpha=0.3,
color='red', label='Negative')
ax.plot(x, y1, 'k-', linewidth=2)
ax.plot(x, y2, 'k-', linewidth=2)
ax.axhline(0, color='black', linewidth=0.5)
ax.legend()
ax.set_title('Filled Areas')
plt.show()
Summary
Matplotlib is incredibly powerful and flexible. Key takeaways:
- Use the OO interface for complex plots and production code
- Customize everything - matplotlib gives you full control
- Plan your layout with GridSpec for complex figures
- Think about your audience - adjust style for presentations vs publications
- Use the right format - vector (PDF/SVG) for publications, raster (PNG) for web
- Manage memory - close figures, downsample large datasets
- Leverage colormaps thoughtfully - use perceptually uniform for data
- Practice common patterns - ML visualizations, multi-panel figures
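The takeaways above can be combined into one small end-to-end recipe; the filenames below are placeholders.

```python
import numpy as np
import matplotlib.pyplot as plt

# OO interface: explicit Figure/Axes objects
fig, ax = plt.subplots(figsize=(6, 4))
x = np.linspace(0, 10, 200)
ax.plot(x, np.sin(x), label='sin(x)')
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.legend()
ax.grid(alpha=0.3)
# Right format for the audience: vector for print, raster for web
fig.savefig('figure.pdf')            # publication
fig.savefig('figure.png', dpi=150)   # web
# Manage memory: release the figure when done
plt.close(fig)
```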
Next Steps:
- Explore Seaborn for statistical visualizations
- Try Plotly for interactive plots
- Check out matplotlib gallery for inspiration
- Read matplotlib cheatsheets
Resources:
Real-Time Operating Systems (RTOS)
This directory contains guides for real-time operating systems used in embedded development.
Contents
RTOS Concepts
Real-Time: Guarantees task execution within specified time constraints
Deterministic: Predictable behavior and timing
Scheduling: Priority-based task execution
Inter-Task Communication:
- Queues: Message passing
- Semaphores: Synchronization
- Mutexes: Mutual exclusion
- Event flags: Thread synchronization
Comparison
| Feature | FreeRTOS | ThreadX |
|---|---|---|
| License | MIT | MIT (open-sourced as Eclipse ThreadX in 2024) |
| Footprint | Very small | Small |
| Scheduling | Preemptive | Preemptive |
| Priority levels | Configurable | 32 levels |
| Use cases | IoT, embedded | Industrial, IoT |
RTOS systems provide deterministic task scheduling essential for time-critical embedded applications.
FreeRTOS
FreeRTOS is a real-time operating system kernel for embedded devices. It's designed to be small, simple, and easy to use.
Core Concepts
Tasks: Independent threads of execution
Queues: Inter-task communication
Semaphores: Synchronization
Mutexes: Mutual exclusion
Timers: Software timers
Event Groups: Synchronization
Task Creation
#include "FreeRTOS.h"
#include "task.h"
void vTaskFunction(void *pvParameters) {
for(;;) {
// Task code
vTaskDelay(pdMS_TO_TICKS(1000)); // Delay 1 second
}
}
int main(void) {
xTaskCreate(
vTaskFunction, // Function
"TaskName", // Name
128, // Stack depth (in words, not bytes)
NULL, // Parameters
1, // Priority
NULL // Task handle
);
vTaskStartScheduler(); // Start scheduler
for(;;); // Should never reach here
}
Queues
#include "queue.h"
#include <stdio.h> // for printf in the receiver
QueueHandle_t xQueue;
// Create the queue once before the scheduler starts (e.g. in main),
// so it exists before either task runs:
// xQueue = xQueueCreate(10, sizeof(int));
void vSenderTask(void *pvParameters) {
int value = 42;
for(;;) {
xQueueSend(xQueue, &value, portMAX_DELAY);
vTaskDelay(pdMS_TO_TICKS(1000));
}
}
void vReceiverTask(void *pvParameters) {
int received;
for(;;) {
if(xQueueReceive(xQueue, &received, portMAX_DELAY)) {
printf("Received: %d\n", received);
}
}
}
Semaphores
#include "semphr.h"
SemaphoreHandle_t xSemaphore;
void vTask1(void *pvParameters) {
for(;;) {
if(xSemaphoreTake(xSemaphore, portMAX_DELAY)) {
// Critical section
xSemaphoreGive(xSemaphore);
}
}
}
Priority Levels
- Higher number = higher priority
- Idle task = priority 0
- Typical range: 0-31
- Preemptive scheduling (default)
FreeRTOS provides essential RTOS functionality in a small footprint, ideal for resource-constrained embedded systems.
ThreadX
ThreadX is a real-time operating system (RTOS) designed for deeply embedded applications. It's known for its small footprint and fast execution.
Core Concepts
Threads: Execution contexts
Message Queues: Inter-thread communication
Semaphores: Synchronization
Mutexes: Resource protection
Event Flags: Thread synchronization
Memory Pools: Dynamic memory management
Thread Creation
#include "tx_api.h"
TX_THREAD my_thread;
UCHAR thread_stack[1024];
void my_thread_entry(ULONG thread_input) {
while(1) {
// Thread logic
tx_thread_sleep(100); // Sleep 100 ticks
}
}
void tx_application_define(void *first_unused_memory) {
tx_thread_create(
&my_thread, // Thread control block
"My Thread", // Name
my_thread_entry, // Entry function
0, // Input
thread_stack, // Stack start
sizeof(thread_stack), // Stack size
16, // Priority (0 = highest, 31 = lowest)
16, // Preemption threshold
TX_NO_TIME_SLICE, // Time slice
TX_AUTO_START // Auto start
);
}
Message Queues
TX_QUEUE my_queue;
UCHAR queue_area[100 * sizeof(ULONG)];
// Create queue
tx_queue_create(
&my_queue,
"My Queue",
TX_1_ULONG, // Message size
queue_area,
sizeof(queue_area)
);
// Send message
ULONG message = 0x12345678;
tx_queue_send(&my_queue, &message, TX_WAIT_FOREVER);
// Receive message
ULONG received;
tx_queue_receive(&my_queue, &received, TX_WAIT_FOREVER);
Semaphores
TX_SEMAPHORE my_semaphore;
// Create counting semaphore
tx_semaphore_create(&my_semaphore, "My Semaphore", 1);
// Get semaphore
tx_semaphore_get(&my_semaphore, TX_WAIT_FOREVER);
// Put semaphore
tx_semaphore_put(&my_semaphore);
Mutex
TX_MUTEX my_mutex;
// Create mutex
tx_mutex_create(&my_mutex, "My Mutex", TX_NO_INHERIT);
// Get mutex
tx_mutex_get(&my_mutex, TX_WAIT_FOREVER);
// Release mutex
tx_mutex_put(&my_mutex);
ThreadX is widely used in embedded systems, particularly in IoT and industrial applications, offering deterministic real-time performance.